diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/deepep_utils.py index 117f1babdf6..0c8895b2449 100644 --- a/tests/kernels/moe/deepep_utils.py +++ b/tests/kernels/moe/deepep_utils.py @@ -138,9 +138,7 @@ def make_deepep_ht_a2a(pg: ProcessGroup, rank=pgi.rank, dp_size=dp_size, rank_expert_offset=pgi.rank * - ht_args.num_local_experts, - quant_dtype=q_dtype, - block_shape=block_shape) + ht_args.num_local_experts) def make_deepep_ll_a2a(pg: ProcessGroup, @@ -168,8 +166,6 @@ def make_deepep_ll_a2a(pg: ProcessGroup, world_size=pgi.world_size, dp_size=dp_size, max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, - quant_dtype=q_dtype, - block_shape=block_shape, use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, ) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index b0e0feab468..67fa66686c2 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -2,18 +2,36 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from typing import Optional import pytest import torch import triton.language as tl +from tests.kernels.moe.utils import (batched_moe, + make_quantized_test_activations, + make_test_weights, triton_moe) +from tests.kernels.quant_utils import native_batched_masked_quant_matmul +from tests.kernels.utils import torch_experts +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( invoke_moe_batched_triton_kernel) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.platforms import current_platform + +NUM_EXPERTS = [8, 64] +TOP_KS = [1, 2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 @dataclass class BatchedMMConfig: - dtype: torch.dtype + in_dtype: torch.dtype + quant_dtype: Optional[torch.dtype] + out_dtype: torch.dtype num_experts: int max_tokens_per_expert: int K: int @@ -32,79 +50,126 @@ def make_tensors(config: BatchedMMConfig): A = torch.randn( (config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", - dtype=config.dtype) / 10 + dtype=config.in_dtype) / 10 B = torch.randn((config.num_experts, config.N, config.K), device="cuda", - dtype=config.dtype) + dtype=config.in_dtype) C = torch.zeros( (config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", - dtype=config.dtype) + dtype=config.out_dtype) + num_expert_tokens = torch.randint(low=0, high=config.max_tokens_per_expert, size=(config.num_experts, ), device="cuda", dtype=torch.int32) - return BatchedMMTensors(A, B, C, num_expert_tokens) - -def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - num_expert_tokens: torch.Tensor) -> torch.Tensor: - - num_expert_tokens_cpu = num_expert_tokens.clone() - num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") - num_experts = num_expert_tokens.size(0) - - for e in range(num_experts): - num_tokens = num_expert_tokens_cpu[e] - C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) - - return C + return BatchedMMTensors(A, B, C, num_expert_tokens) -@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("num_experts", [8, 16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float32, 
torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("block_shape", [None]) +@pytest.mark.parametrize("per_act_token_quant", [False]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, - N: int, dtype: torch.dtype): + N: int, dtype: torch.dtype, + block_shape: Optional[list[int]], + per_act_token_quant: bool): + current_platform.seed_everything(7) - config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) - tensors = BatchedMMTensors.make_tensors(config) + use_fp8_w8a8 = dtype == torch.float8_e4m3fn - test_output = tensors.C - ref_output = test_output.clone() + if block_shape is not None and not use_fp8_w8a8: + pytest.skip("Don't test blocking for non-quantized types.") + + if dtype.itemsize == 1: + act_dtype = torch.bfloat16 + quant_dtype = dtype + else: + act_dtype = dtype + quant_dtype = None + + #print(f"TYPES {dtype}, {act_dtype}, {quant_dtype}") + + num_expert_tokens = torch.randint(low=0, + high=max_tokens_per_expert, + size=(num_experts, ), + device="cuda", + dtype=torch.int32) + + A, A_q, A_scale = make_quantized_test_activations( + num_experts, + max_tokens_per_expert, + K, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant) + + B, B_q, B_scale, _, _, _ = make_test_weights( + num_experts, + N // 2, + K, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + block_shape=block_shape, + ) + + out_shape = (num_experts, max_tokens_per_expert, N) + test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") + ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") + q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") compute_tl_dtype = { torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32 }[test_output.dtype] + + assert A_q.dtype == B_q.dtype + invoke_moe_batched_triton_kernel( - tensors.A, - tensors.B, + A_q, + B_q, test_output, - tensors.num_expert_tokens, + num_expert_tokens, compute_tl_dtype, # Quantization data - None, - None, + A_scale, + B_scale, None, # Quantization schemes - False, + use_fp8_w8a8, False, False, config={ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16 - }) + "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 + }, + block_shape=block_shape, + ) - ref_output = ref_impl(tensors.A, tensors.B, ref_output, - tensors.num_expert_tokens) + ref_output = native_batched_masked_quant_matmul( + A, + B, + ref_output, + num_expert_tokens, + None, + None, + None, + ) + + q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output, + num_expert_tokens, + A_scale, B_scale, + block_shape) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -112,4 +177,83 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.float32: (1e-2, 1e-2), }[test_output.dtype] - torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) + torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("m", [1, 32, 45, 64, 222]) +@pytest.mark.parametrize("n", [128, 512, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024, 2048]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("per_act_token_quant", [False]) +@pytest.mark.parametrize("block_shape", [None]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + 
topk: int, + dtype: torch.dtype, + per_act_token_quant: bool, + block_shape: Optional[list[int]], +): + current_platform.seed_everything(7) + + use_fp8_w8a8 = dtype == torch.float8_e4m3fn + + if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None: + pytest.skip("Skip quantization test for non-quantized type") + + if per_act_token_quant and block_shape is not None or topk > e: + pytest.skip("Skip illegal quantization test") + + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + + if dtype.itemsize == 1: + act_dtype = torch.bfloat16 + quant_dtype = dtype + else: + act_dtype = dtype + quant_dtype = None + + _, w1, w1_s, _, w2, w2_s = make_test_weights(e, + n, + k, + block_shape=block_shape, + in_dtype=act_dtype, + quant_dtype=quant_dtype) + + torch.set_printoptions(profile="full") + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, + w2_s, quant_dtype, per_act_token_quant, + block_shape) + baseline_output = torch_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape) + triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s, + w2_s, quant_dtype, per_act_token_quant, + block_shape) + + torch.testing.assert_close(triton_output, + baseline_output, + atol=2e-2, + rtol=2e-2) + + torch.testing.assert_close(triton_output, + batched_output, + atol=2e-2, + rtol=2e-2) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py new file mode 100644 index 00000000000..8e8a2229206 --- /dev/null +++ b/tests/kernels/moe/test_block_fp8.py @@ -0,0 +1,372 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import itertools + +import pytest +import torch + +from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, + native_w8a8_block_matmul, + per_block_cast_to_fp8) +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm_shape, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, modular_triton_fused_moe) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.platforms import current_platform + +dg_available = False +try: + import deep_gemm + dg_available = True +except ImportError: + pass + +if current_platform.get_device_capability() < (9, 0): + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", + allow_module_level=True) + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +# Test configurations +DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] +NUM_TOKENS = [7, 2050] +D = [512, 4096, 5120, 13824] +GROUP_SIZE = [64, 128, 512] +# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 +# and its hidden size is 7168. 
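+# The N and K lists below include 4608 and 7168 so these DeepSeek-V3-style
+# shapes are covered alongside smaller, faster-to-test sizes.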
+M = [1, 2, 83, 128, 2048, 40000] +M_dg = [128, 192, 1335, 2048] +N = [128, 256, 1024, 4608] # [13824] +K = [256, 512, 7168] # [13824] +BLOCK_SIZE = [[128, 128]] +E = [2, 8, 16, 24] # [128, 256] +TOP_KS = [1, 2, 6] +OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] +SEEDS = [0] + + +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + out[mask] = native_w8a8_block_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Skip all tests if CUDA is not available +pytest.importorskip("torch.cuda") + + +@pytest.fixture(autouse=True) +def setup_cuda(): + torch.set_default_device("cuda") + + +@pytest.mark.parametrize( + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, + monkeypatch): + if topk > E: + pytest.skip(f"Skipping test; topk={topk} > E={E}") + + torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_act_token_quant=False, + block_shape=block_size) + + # Set the context to avoid lots of warning spam. 
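+    # Run the fused Triton kernel, the native torch reference and the modular
+    # kernel on the same inputs and compare their outputs via relative error.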
+ with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + m_out = m_fused_moe(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=E, + w1_scale=w1_s, + w2_scale=w2_s) + + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 + + rel_diff = (torch.mean( + torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 + + +def fp8_perm(m, idx): + if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + else: + return m[idx, ...] + + +def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): + M, K = a.shape + + sorted_token_ids, m_indices, num_pad = moe_align_block_size( + topk_ids, block_m, num_groups, None, pad_sorted_ids=True) + + num_tokens = topk * M + + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:M * topk] + + a = fp8_perm(a, sorted_token_ids // topk) + if a_s is not None: + a_s = a_s[sorted_token_ids // topk] + + return a, a_s, m_indices, inv_perm + + +def _moe_unpermute(out, inv_perm, topk, K, topk_weight): + M = topk_weight.shape[0] + out = out[inv_perm, ...] + tmp_out = out.view(-1, topk, K) + return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, + block_shape): + """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" + num_groups = w1.shape[0] + M, K = a.shape + N = w2.shape[-1] + + topk_weight, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + + _, block_k = block_shape[0], block_shape[1] + + a_q, a_s = per_token_group_quant_fp8(a, block_m) + + a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) + + inter_out = torch.zeros((a_q.shape[0], N * 2), + dtype=torch.bfloat16, + device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), + inter_out, m_indices) + + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + + out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + + final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) + + return final_out + + +@pytest.mark.parametrize("M,N,K,E,topk,seed", + itertools.product(M_dg, N, K, E, TOP_KS, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, + monkeypatch): + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") + + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") + + chunk_size = 1024 + + 
torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + dtype = torch.bfloat16 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + score = torch.randn((M, E), dtype=dtype) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * N) + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w2 = (N + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + + w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + + assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + for i in range(E): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False + + use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 + and current_platform.is_cuda_alike()) + + # Set the context to avoid lots of warning spam. 
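+    # The reference output depends on M: larger batches use the DeepGemm
+    # grouped-gemm reference, smaller ones fall back to the torch reference.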
+ with set_current_vllm_config(vllm_config): + if M >= 128: + ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + else: + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) + + if use_compile: + deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, + backend="inductor", + fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(topk_weights, 0) + torch._dynamo.mark_dynamic(topk_ids, 0) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + + if use_cudagraph: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + + assert rel_diff < 0.03 diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py new file mode 100644 index 00000000000..599f81247bb --- /dev/null +++ b/tests/kernels/moe/test_block_int8.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py +import itertools + +import pytest +import torch + +from tests.kernels.quant_utils import (native_per_token_group_quant_int8, + native_w8a8_block_matmul) +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (7, 0): + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", + allow_module_level=True) + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +DTYPES = [torch.half, torch.bfloat16] +M = [1, 33, 64, 222] +N = [128, 1024] +K = [256, 4096] +E = [8, 24] +TOP_KS = [2, 6] +# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] +BLOCK_SIZE = [[128, 128]] +SEEDS = [0] + + +# For test +def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """This function performs fused moe with block-wise quantization using + native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_int8(a, block_k) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_int8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = 
native_w8a8_block_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.fixture(autouse=True, scope="module") +def setup_cuda(): + """Sets the default CUDA device for all tests in this module.""" + torch.set_default_device("cuda") + + +@pytest.mark.parametrize( + "M, N, K, E, topk, block_size, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + """Tests the fused_moe kernel with W8A8 INT8 block quantization against a + native torch reference.""" + torch.manual_seed(seed) + # Use a smaller factor for scale initialization to prevent large + # values/overflow especially when output dtype might be float16 + factor_for_scale = 1e-2 + int8_info = torch.iinfo(torch.int8) + int8_max, int8_min = int8_info.max, int8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand( + (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max + w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max + w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = (torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale) + w2_s = (torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale) + + score = torch.randn((M, E), dtype=dtype) + + # Set the context to avoid lots of warning spam. + with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_int8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + # Check results + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.06 diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index ce420901e31..158100a0987 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -29,7 +29,10 @@ (224, 1024, 1536), (224, 3072, 1024), (224, 3072, 1536), - (1024 * 128, 1024, 1024), + (32768, 1024, 1024), + # These sizes trigger wrong answers. 
+ #(7232, 2048, 5120), + #(40000, 2048, 5120), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( @@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): dtype = torch.half @@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP( per_act_token: bool, per_out_channel: bool, ep_size: int, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 2d7cf39a8cc..5bd450a47ef 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -Test DeepEP + DeepGEMM integration +Test DeepEP + DeepGEMM integration DeepGEMM are gemm kernels specialized for the fp8 block-quantized case. """ @@ -21,16 +21,12 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform +from vllm.utils import cdiv from .deepep_utils import ProcessGroupInfo, parallel_launch has_deep_ep = importlib.util.find_spec("deep_ep") is not None - -try: - import deep_gemm - has_deep_gemm = True -except ImportError: - has_deep_gemm = False +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None if has_deep_ep: from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 @@ -72,8 +68,7 @@ def per_block_cast_to_fp8( assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), + (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x @@ -210,7 +205,8 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank, world_size=pgi.world_size, dp_size=dp_size, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + per_act_token_quant=True) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk @@ -432,6 +428,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, """ Tests for High-Throughput DeepEP + DeepGemm integration. """ + import deep_gemm m, n, k = mnk current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 7e029ea9505..53085beeb1d 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -154,20 +154,25 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, deepep_ll_args = ll_args) if low_latency_mode: + # TODO(bnell): block_shape? 
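+        # The low-latency path uses the batched Triton experts; both branches
+        # build their FusedMoEQuantConfig with bfloat16 activations and enable
+        # fp8 w8a8 only when is_quantized is set.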
fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, world_size=pgi.world_size, dp_size=dp_size, - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False) + FusedMoEQuantConfig.make( + act_dtype=torch.bfloat16, + use_fp8_w8a8=is_quantized, + ), + ) else: - fused_experts = TritonExperts(use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False) + # TODO(bnell): block_shape? + fused_experts = TritonExperts( + FusedMoEQuantConfig.make( + act_dtype=torch.bfloat16, + use_fp8_w8a8=is_quantized, + per_act_token_quant=False, + ), + ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index bed374cf4d5..479c69342f2 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,6 +4,9 @@ Run `pytest tests/kernels/test_moe.py`. """ +import functools +from typing import Callable, Optional, Union + import pytest import torch from torch.nn import Parameter @@ -40,7 +43,76 @@ vllm_config.scheduler_config.max_model_len = 8192 -@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) +def run_moe_test( + baseline: Union[Callable, torch.Tensor], + moe_fn: Callable, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + padding: bool = False, + use_compile: bool = False, + use_cudagraph: bool = False, + atol: float = 2e-2, + rtol: float = 0, +) -> torch.Tensor: + if isinstance(baseline, torch.Tensor): + baseline_output = baseline + else: + baseline_output = baseline(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + + if use_compile: + moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(score, 0) + + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + if use_cudagraph: + test_output.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(test_output, + baseline_output, + atol=atol, + rtol=rtol) + + return baseline_output + + +@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -48,6 +120,7 @@ @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) +@pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( m: int, n: int, @@ -57,7 +130,17 @@ def test_fused_moe( ep_size: int, dtype: torch.dtype, padding: bool, + chunk_size: int, + monkeypatch, ): + current_platform.seed_everything(7) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + # + # Setup test data + # + a = torch.randn((m, k), device="cuda", dtype=dtype) 
/ 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -77,58 +160,70 @@ def test_fused_moe( else: e_map = None - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None) - - with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() + # + # Setup test functions + # + + m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_act_token_quant=False, + block_shape=None) + + def m_fused_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + return m_fused_moe_fn(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map) + + fused_moe_fn = functools.partial(fused_moe, renormalize=False) + + # + # Run tests + # + runner = functools.partial( + run_moe_test, + a=a, + w1=w1, + w2=w2, + score=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + padding=padding, + ) - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. 
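+    # CUDA graph capture below is still exercised for the larger n/k shapes
+    # on CUDA-alike platforms.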
+ use_compile = False - topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - m_triton_output = m_fused_moe(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=e_map) + use_cudagraph = (n >= 1024 and k >= 1024 + and current_platform.is_cuda_alike()) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - torch.testing.assert_close(m_triton_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(iterative_output, - torch_output, - atol=2e-2, - rtol=0) + with set_current_vllm_config(vllm_config): + baseline_output = runner(torch_moe, iterative_moe) + runner(baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph) + runner(baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph) @pytest.mark.parametrize("m", [1, 32, 222]) @@ -238,7 +333,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch_output = torch_moe(a, + w1_ref, + w2_ref, + score, + topk, + expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -265,45 +365,49 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, pytest.skip("AITER ROCm test skip for float32") # Instantiate our and huggingface's MoE blocks - config = MixtralConfig() - hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") - vllm_moe = MixtralMoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=dtype, - tp_size=1, - dp_size=1, - ).cuda() - - # Load the weights - vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data - for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) - vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) - vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data - - # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") - # vLLM uses 1D query [num_tokens, hidden_dim] - vllm_inputs = hf_inputs.flatten(0, 1) + with set_current_vllm_config(vllm_config): + config = MixtralConfig() + hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") + vllm_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=dtype, + tp_size=1, + dp_size=1, + ).cuda() + + # Load the weights + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data + for i in range(config.num_local_experts): + weights = (hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data) + vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data + + # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] + hf_inputs = torch.randn( + (1, 64, config.hidden_size)).to(dtype).to("cuda") + # vLLM uses 1D query [num_tokens, hidden_dim] + vllm_inputs = hf_inputs.flatten(0, 1) - # Pad the weight if moe padding is enabled - if padding: - vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], - 
requires_grad=False) - torch.cuda.empty_cache() - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - - # Run forward passes for both MoE blocks - hf_states, _ = hf_moe.forward(hf_inputs) - vllm_states = vllm_moe.forward(vllm_inputs) + # Pad the weight if moe padding is enabled + if padding: + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + + # Run forward passes for both MoE blocks + hf_states, _ = hf_moe.forward(hf_inputs) + vllm_states = vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3, @@ -546,7 +650,12 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + torch_output = torch_moe(a, + w_ref1, + w_ref2, + score, + topk, + expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 22482d9ca85..3f5412e7582 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + pytest.skip("Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True) MNK_FACTORS = [ @@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, device=w2.device, block_size=quant_blocksize) - torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch.testing.assert_close(torch_output, cutlass_output, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index d90202dfcb3..739bc560b87 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -6,9 +6,9 @@ import pytest import torch +from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -93,7 +93,7 @@ def pplx_cutlass_moe( num_experts=num_experts, experts_per_token=topk, rank=rank, - world_size=pgi.world_size, + world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1 @@ -118,8 +118,6 @@ def pplx_cutlass_moe( pgi.world_size, rank, dp_size, - quant_dtype=torch.float8_e4m3fn, - per_act_token=per_act_token, ) experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, @@ -164,22 +162,6 @@ def pplx_cutlass_moe( vllm_config.scheduler_config.max_model_len = 8192 -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = 
torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -210,8 +192,8 @@ def _pplx_moe( group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): - torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, - topk_ids) + torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights, + topk_ids) pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale, out_dtype, per_act_token, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 2d6a8f39cec..c817bf20d90 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,16 +18,18 @@ except ImportError: has_pplx = False +from tests.kernels.moe.utils import make_test_weights, naive_batched_moe +from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import override_config +from vllm.model_executor.layers.fused_moe import fused_topk, override_config +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, - get_default_config) + BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform +from vllm.utils import round_up from .deepep_utils import ProcessGroupInfo, parallel_launch @@ -144,48 +146,6 @@ def torch_batched_moe( return torch_finalize(out, topk_weight, topk_ids) -def batched_moe( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - num_experts = w1.shape[0] - - fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens=a.shape[0], - world_size=1, - dp_size=1, - rank=0), - BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) - - return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) - - -# Note: same as torch_moe but with fused_topk factored out. 
-def torch_moe2( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -209,9 +169,9 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) - batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids) torch.testing.assert_close(baseline_output, torch_output, @@ -249,7 +209,6 @@ def pplx_prepare_finalize( topk = topk_ids.shape[1] num_tokens, hidden_dim = a.shape - block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size @@ -264,9 +223,7 @@ def pplx_prepare_finalize( dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else - ((hidden_dim + block_size - 1) // block_size * - torch.float32.itemsize)), + hidden_dim_scale_bytes=0, ) if group_name is None: @@ -283,7 +240,6 @@ def pplx_prepare_finalize( world_size, rank, dp_size, - a.dtype, ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) @@ -299,6 +255,7 @@ def pplx_prepare_finalize( num_experts, None, False, + FusedMoEQuantConfig(), ) b_a = b_a * 1.5 @@ -373,10 +330,11 @@ def _pplx_prepare_finalize( # TODO (bnell): this test point does not work for odd M due to how the test is # written, not due to limitations of the pplx kernels. The pplx_moe # test below is able to deal with odd M. 
+# TODO (bnell) add fp8 tests @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx @@ -409,18 +367,31 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - use_compile: bool = True, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + qtype: Optional[torch.dtype] = None, + per_act_token_quant=False, + block_shape: Optional[list[int]] = None, + use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) device = torch.device("cuda", rank) hidden_dim = a.shape[1] num_experts = w1.shape[0] - block_size = 128 topk = topk_ids.shape[1] - max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64) + + hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( + max_num_tokens, + hidden_dim, + a.dtype, + qtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) args = dict( max_num_tokens=max_num_tokens, @@ -430,10 +401,8 @@ def pplx_moe( world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else - ((hidden_dim + block_size - 1) // block_size * - torch.float32.itemsize)), + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=scale_bytes, ) if group_name is None: @@ -452,9 +421,14 @@ def pplx_moe( dp_size, ) - experts = BatchedTritonExperts(max_num_tokens=a.shape[0], + experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, world_size=world_size, - dp_size=dp_size) + dp_size=dp_size, + FusedMoEQuantConfig.make( + use_fp8_w8a8=qtype==torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape), + ) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -470,10 +444,24 @@ def pplx_moe( w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + # TODO scale chunk function + if w1_scale is not None: + w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) + w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) + else: + w1_scale_chunk = None + w2_scale_chunk = None + + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. 
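+    # When compilation is enabled, the token dimension of the inputs is
+    # marked dynamic so the compiled graph can be reused for different
+    # numbers of tokens.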
if use_compile: _fused_experts = torch.compile(fused_experts, backend='inductor', fullgraph=True) + torch._dynamo.mark_dynamic(a_chunk, 0) + torch._dynamo.mark_dynamic(chunk_topk_weight, 0) + torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts @@ -482,6 +470,8 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, global_num_experts=num_experts) if use_cudagraphs: @@ -494,6 +484,8 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, global_num_experts=num_experts) torch.cuda.synchronize() @@ -522,9 +514,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): rank=rank, ) - experts = BatchedExperts(max_num_tokens=a.shape[0], - world_size=1, - dp_size=1) + experts = NaiveBatchedExperts(max_num_tokens=a.shape[0], + world_size=1, + dp_size=1) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -556,7 +548,12 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, - use_internode: bool, + w1_s: Optional[torch.Tensor] = None, + w2_s: Optional[torch.Tensor] = None, + qtype: Optional[torch.dtype] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, + use_internode: bool = False, ): if use_internode: uid = nvshmem_get_unique_id( @@ -574,11 +571,28 @@ def _pplx_moe( moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) + device = torch.device("cuda", pgi.rank) + a = a.to(device) + w1 = w1.to(device) + w2 = w2.to(device) + w1_s = w1_s.to(device) if w1_s is not None else None + w2_s = w2_s.to(device) if w2_s is not None else None + with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + quant_dtype=qtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape) pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, - a, w1, w2, topk_weight, topk_ids) + a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, + qtype, per_act_token_quant, block_shape) # TODO (bnell): fix + re-enable #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, # topk_ids) @@ -596,8 +610,10 @@ def _pplx_moe( @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) # torch.float8_e4m3fn, @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_moe( @@ -606,15 +622,33 @@ def test_pplx_moe( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]], use_internode: bool, ): current_platform.seed_everything(7) m, n, k = mnk world_size, dp_size = world_dp_size - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + + if dtype == torch.float8_e4m3fn: + use_fp8_w8a8 = True + quant_dtype = dtype + 
else: + use_fp8_w8a8 = False + quant_dtype = None + + if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None: + pytest.skip("Skip quantization test for non-quantized type") + + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + + _, w1, w1_s, _, w2, w2_s = make_test_weights(e, + n, + k, + quant_dtype=quant_dtype, + block_shape=block_shape) parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, + w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, use_internode) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py new file mode 100644 index 00000000000..b19591eac1c --- /dev/null +++ b/tests/kernels/moe/utils.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import round_up + + +def triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + quant_type: Optional[torch.dtype] = None, + per_act_token_quant=False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + return fused_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + per_channel_quant=per_act_token_quant, + use_fp8_w8a8=quant_type == torch.float8_e4m3fn, + block_shape=block_shape) + + +def batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + qtype: Optional[torch.dtype] = None, + per_act_token: bool = False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + max_num_tokens = round_up(a.shape[0], 64) + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(max_num_tokens, + world_size=1, + dp_size=1, + rank=0), + BatchedTritonExperts(max_num_tokens=max_num_tokens, + world_size=1, + dp_size=1, + FusedMoEQuantConfig.make( + use_fp8_w8a8=qtype == torch.float8_e4m3fn, + per_act_token_quant=per_act_token, + block_shape=block_shape), + ) + + return fused_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale) + + +def naive_batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), + NaiveBatchedExperts(max_num_tokens=a.shape[0], dp_size=1, + world_size=1)) + + return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + from vllm.utils import cdiv + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), + dtype=x.dtype, + 
device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + +def make_quantized_test_activations( + E: int, + m: int, + k: int, + in_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert not per_act_token_quant, "NYI" + + a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10 + a_q = a + a_scale = None + + if quant_dtype is not None: + assert quant_dtype == torch.float8_e4m3fn, "only fp8 supported" + a_q = torch.zeros_like(a, dtype=quant_dtype) + a_scale = [None] * E + for e in range(E): + if block_shape is not None: + a_q[e], a_scale[e] = per_token_group_quant_fp8( + a[e], block_shape[1]) + else: + a_tmp, a_scale[e] = per_token_group_quant_fp8( + a[e].view(1, -1), a[e].numel()) + a_q[e] = a_tmp.view(*a[e].shape) + a_scale = torch.stack(a_scale) + + return a, a_q, a_scale + + +def make_test_weights( + e: int, + n: int, + k: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: + w1_16 = torch.randn((e, 2 * n, k), device="cuda", dtype=in_dtype) / 15 + w2_16 = torch.randn((e, k, n), device="cuda", dtype=in_dtype) / 15 + + if quant_dtype is not None: + assert quant_dtype == torch.float8_e4m3fn, "only fp8 supported" + w1_l = [None] * e + w2_l = [None] * e + w1_s_l = [None] * e + w2_s_l = [None] * e + for idx in range(e): + if block_shape is not None: + w1_l[idx], w1_s_l[idx] = per_block_cast_to_fp8( + w1_16[idx], + block_shape[1], + ) + w2_l[idx], w2_s_l[idx] = per_block_cast_to_fp8( + w2_16[idx], + block_shape[1], + ) + else: + tmp, w1_s_l[idx] = per_token_group_quant_fp8( + w1_16[idx].view(1, -1), w1_16[idx].numel()) + w1_l[idx] = tmp.view(*w1_16[idx].shape) + + tmp, w2_s_l[idx] = per_token_group_quant_fp8( + w2_16[idx].view(1, -1), w2_16[idx].numel()) + w2_l[idx] = tmp.view(*w2_16[idx].shape) + + w1 = torch.stack(w1_l) + w2 = torch.stack(w2_l) + w1_s = torch.stack(w1_s_l) + w2_s = torch.stack(w2_s_l) + if w1_s.ndim == 2: + assert w1_s.shape[-1] == 1 + w1_s = w1_s.view(-1, 1, 1) + w2_s = w2_s.view(-1, 1, 1) + + if block_shape is not None: + block_n, block_k = block_shape + n_tiles_w1 = ((2 * n) + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w2 = (n + block_k - 1) // block_k + assert w1_s.shape == (e, n_tiles_w1, k_tiles_w1) + assert w2_s.shape == (e, n_tiles_w2, k_tiles_w2) + else: + w1 = w1_16 + w2 = w2_16 + w1_s = None + w2_s = None + + return w1_16, w1, w1_s, w2_16, w2, w2_s diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 0840cc7b54f..3c50ef1cff8 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -6,6 +6,7 @@ import torch 
from vllm.platforms import current_platform +from vllm.utils import cdiv # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. @@ -94,9 +95,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ return ref_out, ref_scale.view((1, )) -def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, - As: torch.Tensor, Bs: torch.Tensor, block_size, - output_dtype): +def native_w8a8_block_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, + compute_type: torch.dtype = torch.float32, +) -> torch.Tensor: """This function performs matrix multiplication with block-wise quantization using native torch. It is agnostic to the input data type and can be used for both int8 and @@ -106,8 +113,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, `Bs` (float32). The output is returned in the specified `output_dtype`. """ - A = A.to(torch.float32) - B = B.to(torch.float32) + A = A.to(compute_type) + B = B.to(compute_type) assert A.shape[-1] == B.shape[-1] assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 assert len(block_size) == 2 @@ -122,11 +129,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, As = As.reshape(M, As.shape[-1]) n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] + assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}" + assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}" C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + C = torch.zeros(C_shape, dtype=compute_type, device=A.device) A_tiles = [ A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) @@ -152,3 +159,112 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, C = C.reshape(origin_C_shape).to(output_dtype) return C + + +def native_per_token_group_quant_fp8(x, + group_size, + eps=1e-10, + dtype=torch.float8_e4m3fn): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch.""" + assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " + "be divisible by `group_size`") + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +def native_per_token_group_quant_int8(x, + group_size, + eps=1e-10, + dtype=torch.int8): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch. + + It converts the tensor values into int8 values and returns the + quantized tensor along with the scaling factor used for quantization. 
+ """ + assert (x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_min = iinfo.min + int8_max = iinfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + # Use float32 for scale calculation for stability + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / int8_max + x_q = (x_.to(torch.float32) / x_s).round().clamp( + min=int8_min, max=int8_max).to(dtype) # Round before clamping + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def native_batched_masked_quant_matmul( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + num_expert_tokens: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + block_shape: Optional[list[int]], +) -> torch.Tensor: + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + f32 = torch.float32 + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + if A.dtype.itemsize == 1 and block_shape is not None: + assert A_scale is not None and B_scale is not None + tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e], + block_shape, C.dtype) + C[e, :num_tokens, :] = tmp[:num_tokens, :] + elif A.dtype.itemsize == 1 and block_shape is None: + assert A_scale is not None and B_scale is not None + C[e, :num_tokens, :] = ( + (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(C.dtype) + @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(C.dtype)) + else: + assert A_scale is None + assert B_scale is None + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + return C diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index eec59573792..42d5526dc21 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -7,16 +7,10 @@ import pytest import torch -from tests.kernels.quant_utils import native_w8a8_block_matmul -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) +from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, + native_w8a8_block_matmul, + per_block_cast_to_fp8) +from vllm.config import VllmConfig from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -46,78 +40,10 @@ K = [256, 3884, 4096, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128] -M_moe_dg = [128, 192, 1335, 2048] -N_moe = [128, 256, 1024, 4608] # [13824] -K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16, 24] # [128, 256] -TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] - -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): - """Function to perform per-token-group quantization on an input tensor - `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " - "be divisible by `group_size`") - assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) - - return x_q, x_s - - -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - # Skip all tests if CUDA is not available pytest.importorskip("torch.cuda") @@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - if topk > E: - pytest.skip(f"Skipping test; topk={topk} > E={E}") - - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) 
- 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w2_bf16 - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale - w2_s = torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale - - score = torch.randn((M, E), dtype=dtype) - - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=block_size) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - m_out = m_fused_moe(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=E, - w1_scale=w1_s, - w2_scale=w2_s) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 - - rel_diff = (torch.mean( - torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 - - -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) @@ -324,152 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 - - -def fp8_perm(m, idx): - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) - else: - return m[idx, ...] 
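For reference, the scale-grid shapes asserted in the make_test_weights helper added earlier in this diff come straight from ceiling-division tile counts. A minimal sketch with hypothetical sizes (the local cdiv mirrors vllm.utils.cdiv; this is an illustration, not code from the change):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(-a // b)

# Hypothetical sizes: e experts, w1 of shape (e, 2 * n, k), w2 of shape (e, k, n).
e, n, k = 8, 512, 1024
block_n, block_k = 128, 128  # the [128, 128] block shape used throughout

# One float32 scale per (block_n x block_k) tile of each weight matrix.
n_tiles_w1, k_tiles_w1 = cdiv(2 * n, block_n), cdiv(k, block_k)
n_tiles_w2, k_tiles_w2 = cdiv(k, block_n), cdiv(n, block_k)

assert (e, n_tiles_w1, k_tiles_w1) == (8, 8, 8)  # expected w1_s shape
assert (e, n_tiles_w2, k_tiles_w2) == (8, 8, 4)  # expected w2_s shape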
- - -def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): - M, K = a.shape - - sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, block_m, num_groups, None, pad_sorted_ids=True) - - num_tokens = topk * M - - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:M * topk] - - a = fp8_perm(a, sorted_token_ids // topk) - if a_s is not None: - a_s = a_s[sorted_token_ids // topk] - - return a, a_s, m_indices, inv_perm - - -def _moe_unpermute(out, inv_perm, topk, K, topk_weight): - M = topk_weight.shape[0] - out = out[inv_perm, ...] - tmp_out = out.view(-1, topk, K) - return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" - num_groups = w1.shape[0] - M, K = a.shape - N = w2.shape[-1] - - topk_weight, topk_ids, token_expert_indices = fused_topk( - a, score.float(), topk, False) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - - _, block_k = block_shape[0], block_shape[1] - - a_q, a_s = per_token_group_quant_fp8(a, block_m) - - a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) - - inter_out = torch.zeros((a_q.shape[0], N * 2), - dtype=torch.bfloat16, - device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), - inter_out, m_indices) - - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - - out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - - final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) - - return final_out - - -@pytest.mark.parametrize( - "M,N,K,E,topk,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) -@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] - dtype = torch.bfloat16 - - if topk > E: - pytest.skip(f"Skipping test: topk={topk} > E={E}") - - if not _valid_deep_gemm_shape(M, N, K): - pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") - - torch.manual_seed(seed) - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - score = torch.randn((M, E), dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * N) + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w2 = (N + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - - w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - - w1_s = 
deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(E): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - if M >= 128: - ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - else: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - - topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score.float(), topk, False) - - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - - assert rel_diff < 0.03 diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index fa2c9f890d6..fac82cf9c8b 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -8,9 +8,7 @@ import torch from tests.kernels.quant_utils import native_w8a8_block_matmul -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( w8a8_block_int8_matmul) from vllm.platforms import current_platform @@ -23,82 +21,10 @@ vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 - -# For test -def native_per_token_group_quant_int8(x, - group_size, - eps=1e-10, - dtype=torch.int8): - """Function to perform per-token-group quantization on an input tensor - `x` using native torch. - - It converts the tensor values into int8 values and returns the - quantized tensor along with the scaling factor used for quantization. 
- """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - iinfo = torch.iinfo(dtype) - int8_min = iinfo.min - int8_max = iinfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - # Use float32 for scale calculation for stability - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / int8_max - x_q = (x_.to(torch.float32) / x_s).round().clamp( - min=int8_min, max=int8_max).to(dtype) # Round before clamping - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) - - return x_q, x_s - - -# For test -def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """This function performs fused moe with block-wise quantization using - native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_int8(a, block_k) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_int8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - DTYPES = [torch.half, torch.bfloat16] M = [1, 33, 64, 222] N = [128, 1024] K = [256, 4096] -E = [8, 24] -TOP_KS = [2, 6] # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] BLOCK_SIZE = [[128, 128]] SEEDS = [0] @@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 - - -@pytest.mark.parametrize( - "M, N, K, E, topk, block_size, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) -@torch.inference_mode() -def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - """Tests the fused_moe kernel with W8A8 INT8 block quantization against a - native torch reference.""" - torch.manual_seed(seed) - # Use a smaller factor for scale initialization to prevent large - # values/overflow especially when output dtype might be float16 - factor_for_scale = 1e-2 - int8_info = torch.iinfo(torch.int8) - int8_max, int8_min = int8_info.max, int8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_fp32 = (torch.rand( - (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max - w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) - - w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max - w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - 
w1_s = (torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale) - w2_s = (torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale) - - score = torch.randn((M, E), dtype=dtype) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.06 diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index d1db6a8eb1b..85ca4974610 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,8 +13,11 @@ import torch from torch._prims_common import TensorLikeType +from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) from vllm.platforms.interface import _Backend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -1054,23 +1057,84 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk, expert_map): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) +def torch_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, + per_act_token_quant=False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + assert (global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None + and global_num_experts == expert_map.shape[0])) + + M, K = a.shape + topk = topk_ids.shape[1] + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype, + per_act_token_quant, block_shape) + + num_experts = w1.shape[0] + topk_ids = topk_ids.view(-1) if expert_map is not None: topk_ids = expert_map[topk_ids] - for i in range(w1.shape[0]): + + for i in range(num_experts): mask = topk_ids == i if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + if quant_dtype is None: + tmp1 = a[mask] @ w1[i].transpose(0, 1) + tmp2 = SiluAndMul()(tmp1) + out[mask] = tmp2 @ w2[i].transpose(0, 1) + elif block_shape is not None: + assert (a_scale is not None and w1_scale is not None + and w2_scale is not None) + tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], + w1_scale[i], block_shape, + out.dtype) + tmp2 = 
SiluAndMul()(tmp1) + tmp2, b_scale = moe_kernel_quantize_input( + tmp2, None, quant_dtype, per_act_token_quant, block_shape) + + out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, + w2_scale[i], block_shape, + out.dtype) + else: + compute_type = torch.bfloat16 + tmp1 = a[mask].to(compute_type) @ w1[i].transpose( + 0, 1).to(compute_type) + tmp2 = SiluAndMul()(tmp1) + out[mask] = (tmp2 @ w2[i].transpose(0, 1).to(compute_type)).to( + out.dtype) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +def torch_moe(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, + expert_map) def torch_moe_single(a, w, score, topk): diff --git a/vllm/envs.py b/vllm/envs.py index 921052821ee..41a1cfb5e87 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -942,6 +942,7 @@ def factorize(name: str): "VLLM_DP_RANK", "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", + "VLLM_FUSED_MOE_CHUNK_SIZE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 2bdc96e297c..3d40879b4cc 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,8 +4,12 @@ from contextlib import contextmanager from typing import Any, Optional +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]: __all__ = [ "FusedMoE", + "FusedMoEConfig", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "FusedMoEPermuteExpertsUnpermute", + "FusedMoEActivationFormat", + "FusedMoEPrepareAndFinalize", "override_config", "get_config", ] @@ -36,11 +44,21 @@ def get_config() -> Optional[dict[str, Any]]: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4, cutlass_moe_fp8) + CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import ( TritonExperts, fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) __all__ += [ "fused_moe", @@ -50,5 +68,11 @@ def get_config() -> Optional[dict[str, Any]]: 
"grouped_topk", "cutlass_moe_fp8", "cutlass_moe_fp4", + "CutlassExpertsFp8", "TritonExperts", + "BatchedTritonExperts", + "DeepGemmExperts", + "BatchedDeepGemmExperts", + "TritonOrDeepGemmExperts", + "BatchedTritonOrDeepGemmExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 5492399efdf..f3ac127bb57 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -6,6 +6,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, per_token_group_quant_fp8) @@ -17,28 +18,44 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 - DEEPGEMM_BLOCK_SHAPE = 128 - - def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, - block_shape: list[int]): + DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] + + def __init__(self, + max_num_tokens: int, + world_size: int, + dp_size: int, + block_shape: list[int], + per_act_token_quant=False): """ max_num_tokens: Maximum number of tokens from a DP Rank world_size: Number of EP ranks dp_size: Number of data-parallel ranks block_shape: Block quantization block shape """ - super().__init__() + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - self.block_shape = block_shape - assert (len(self.block_shape) == 2 and all( - [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) def supports_chunking(self) -> bool: return False + def supports_expert_map(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -86,6 +103,7 @@ def apply( ): import deep_gemm as dg assert hidden_states.ndim == 3 + assert self.block_shape is not None a1q = hidden_states _, N, K = w1.size() diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 822cda8205b..acc4df78586 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -6,57 +6,49 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, - max_num_tokens: int, - world_size: int, - dp_size: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = 
False): - super().__init__() - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" + def __init__( + self, + max_num_tokens: int, + world_size: int, + dp_size: int, + quant_config: Optional[FusedMoEQuantConfig] = None, + allow_deep_gemm: bool = False + ): + assert quant_config is None or (not quant_config.use_int8_w8a8 and + not quant_config.use_int8_w8a16 and + not quant_config.use_int4_w4a16), "NYI" + super().__init__(quant_config) self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_int4_w4a16 = use_int4_w4a16 - self.per_channel_quant = per_channel_quant - self.block_shape = block_shape self.allow_deep_gemm = allow_deep_gemm + # Derive the bare quantization flags used below from quant_config, + # which may be None. + use_fp8_w8a8 = quant_config is not None and quant_config.use_fp8_w8a8 + use_int8_w8a8 = quant_config is not None and quant_config.use_int8_w8a8 + use_int8_w8a16 = (quant_config is not None + and quant_config.use_int8_w8a16) + use_int4_w4a16 = (quant_config is not None + and quant_config.use_int4_w4a16) # BatchedTritonKernel doesn't support block quantization # at the moment. self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=self.max_num_tokens, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape, world_size=self.world_size, - dp_size=self.dp_size) if self.block_shape is None else None + dp_size=self.dp_size, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape, + ) if self.block_shape is None else None + + is_fp8_128_block_quantized = ( + use_fp8_w8a8 and self.block_shape + == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) - is_fp8_128_block_quantized = (self.use_fp8_w8a8 - and self.block_shape is not None - and len(self.block_shape) == 2 and all( - [b == 128 - for b in self.block_shape])) self.batched_deep_gemm_experts = BatchedDeepGemmExperts( max_num_tokens=self.max_num_tokens, world_size=self.world_size, @@ -67,12 +59,31 @@ def __init__(self, assert (self.batched_deep_gemm_experts is not None or self.batched_triton_experts is not None) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.batched_triton_experts is not None: + assert (self.batched_deep_gemm_experts is None + or self.batched_deep_gemm_experts.activation_formats + == self.batched_triton_experts.activation_formats) + return self.batched_triton_experts.activation_formats + else: + assert self.batched_deep_gemm_experts is not None + return self.batched_deep_gemm_experts.activation_formats + def supports_chunking(self) -> bool: bdge = self.batched_deep_gemm_experts bte = self.batched_triton_experts return ((bdge is None or bdge.supports_chunking()) and (bte is None or bte.supports_chunking())) + def supports_expert_map(self) -> bool: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_expert_map()) + and (bte is None or bte.supports_expert_map())) + def workspace_shapes( self, a: torch.Tensor, @@ -87,7 +98,8 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set.
- if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: + if self.allow_deep_gemm: + assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py new file mode 100644 index 00000000000..30ac2042821 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) + +import vllm.envs as envs +from vllm.config import ParallelConfig +from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +logger = init_logger(__name__) + + +def _get_quant_config_quantization_args( + quant_config: Optional[QuantizationConfig], + prop_name: str, +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get(prop_name) + else: + return None + + +def get_quant_config_input_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, + "input_activations") + + +def get_quant_config_weight_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, "weights") + + +def get_config_dtype_str( + act_dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[str]: + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a8: + return "int8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" + elif act_dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +# TODO (bnell): use scalar_type instead of bools? +def get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8 or use_int8_w8a16: + return torch.int8 + return None + + +@dataclass +class FusedMoEQuantConfig: + # The post quantization activation type. + quant_dtype: Optional[torch.dtype] = None + config_dtype_str: Optional[str] = None + per_act_token_quant: bool = False + per_out_ch_quant: bool = False + block_shape: Optional[list[int]] = None + + # TODO: add col major flag? + # add detailed quant info for input, intermediates, weights, etc? 
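A minimal usage sketch of this new config (illustration only, not part of the diff), relying on the make() helper and use_* properties defined just below in this file; the assertions follow from how make() maps the flags onto quant_dtype and config_dtype_str:

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

cfg = FusedMoEQuantConfig.make(
    act_dtype=torch.bfloat16,
    use_fp8_w8a8=True,
    block_shape=[128, 128],
)
# quant_dtype comes from the use_* flags, the use_* properties read
# config_dtype_str, and block_shape is carried through unchanged.
assert cfg.quant_dtype == torch.float8_e4m3fn
assert cfg.use_fp8_w8a8 and not cfg.use_int8_w8a8
assert cfg.block_shape == [128, 128] and not cfg.per_act_token_quant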
+ + @property + def use_fp8_w8a8(self) -> bool: + return self.config_dtype_str == "fp8_w8a8" + + @property + def use_int8_w8a8(self) -> bool: + return self.config_dtype_str == "int8_w8a8" + + @property + def use_int8_w8a16(self) -> bool: + return self.config_dtype_str == "int8_w8a16" + + @property + def use_int4_w4a16(self) -> bool: + return self.config_dtype_str == "int4_w4a16" + + @staticmethod + def make( + act_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, + ) -> "FusedMoEQuantConfig": + """ + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - per_act_token_quant (bool): TODO + Defaults to False. + - per_out_ch_quant (bool): TODO + Defaults to False. + - block_shape: (Optional[list[int]]): Optional block size for block-wise + quantization. + """ + quant_dtype = get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16 + ) + config_dtype_str = get_config_dtype_str( + act_dtype if act_dtype is not None else torch.bfloat16, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16 + ) + return FusedMoEQuantConfig( + quant_dtype, + config_dtype_str, + per_act_token_quant, + per_out_ch_quant, + block_shape, + ) + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + world_size: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + + @staticmethod + def make(tp_size_: int, dp_size_: int, world_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + world_size_ (int): the world size of the current All2All manager. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. 
+ + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non-trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 devices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + world_size=world_size_, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallelism; update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + world_size=world_size_, + use_ep=True) + + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class FusedMoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + # The activation type.
+ in_dtype: torch.dtype + + quant_config: Optional[FusedMoEQuantConfig] = None + + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + + def __post_init__(self): + if self.dp_size > 1: + logger.debug("Using FusedMoEConfig::max_num_tokens=%d", + self.max_num_tokens) + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + if self.quant_config is not None: + return self.quant_config.quant_dtype + else: + return None + + @property + def block_shape(self) -> Optional[list[int]]: + if self.quant_config is not None: + return self.quant_config.block_shape + else: + return None + + @property + def per_act_token_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_act_token_quant + else: + return False + + @property + def per_out_ch_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_out_ch_quant + else: + return False + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def world_size(self): + return self.moe_parallel_config.world_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + + @staticmethod + def make( + num_experts: int, + experts_per_token: int, + hidden_dim: int, + num_local_experts: int, + moe_parallel_config: FusedMoEParallelConfig, + in_dtype: torch.dtype, + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config: Optional[Union[FusedMoEQuantConfig, + QuantizationConfig]] = None + ) -> "FusedMoEConfig": + + _quant_config: Optional[FusedMoEQuantConfig] = None + + if quant_config is not None and isinstance(quant_config, + QuantizationConfig): + block_shape = quant_config.get("weight_block_size", None) + per_act_token_quant = False + per_out_ch_quant = False + quant_dtype: Optional[torch.dtype] = None + + input_quant = get_quant_config_input_quant(quant_config) + weight_quant = get_quant_config_weight_quant(quant_config) + + if input_quant is not None: + per_act_token_quant = (input_quant.strategy + == QuantizationStrategy.TOKEN + if input_quant is not None else False) + + if input_quant.num_bits == 8: + if input_quant.type == QuantizationType.FLOAT: + quant_dtype = torch.float8_e4m3fn + elif input_quant.type == QuantizationType.INT: + quant_dtype = torch.int8 + + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + if quant_dtype is None and isinstance(quant_config, Fp8Config): + quant_dtype = torch.float8_e4m3fn + + if weight_quant is not None: + per_out_ch_quant = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL) + + assert quant_dtype is not None + + _quant_config = FusedMoEQuantConfig( + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) + else: + _quant_config = quant_config + + return FusedMoEConfig( + num_experts=num_experts,
experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + num_local_experts=num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=in_dtype, + quant_config=_quant_config, + max_num_tokens=max_num_tokens, + ) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 3f9ceac8b6e..e0c01bbf074 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache @@ -41,24 +42,24 @@ def run_cutlass_moe_fp8( assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn if expert_num_tokens is None: - assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1" + assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1" else: - assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1" - assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1.shape[1], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2.shape[1], "W2 scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim( - ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ - 0], "Input scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.dim( - ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[ - 0], "Intermediate scale shape mismatch" + assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1" + assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" + assert w1_scale.dim() == 1 or w1_scale.size( + 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.size( + 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert w1.size(0) == w2.size(0), "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( + 0) == 1 or a1q_scale.size( + 0) == a1q.shape[0], "Input scale shape mismatch" + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" + assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" + assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( + 0) == 1 or a2_scale.size( + 0) == a1q.shape[0], "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" if expert_map is not None: assert expert_num_tokens is None @@ -75,12 +76,12 @@ def run_cutlass_moe_fp8( # their tokens are already contiguous for each expert as a result of # the dispatch function. 
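The batched layout that the comment above relies on can be checked with a tiny, self-contained sketch (hypothetical sizes, not part of the diff): each local expert owns a padded slab of padded_M rows, so reshaping a (E, padded_M, K) activation tensor to (E * padded_M, K) keeps every expert's tokens contiguous, which is what the grouped GEMM's per-expert problem sizes assume.

import torch

E, padded_M, K = 4, 8, 16  # hypothetical batched-expert shapes
a1q = torch.arange(E * padded_M * K).reshape(E, padded_M, K)

# Rows [e * padded_M, (e + 1) * padded_M) of the flat view belong to expert e.
flat = a1q.reshape(-1, a1q.size(2))
for e in range(E):
    assert torch.equal(flat[e * padded_M:(e + 1) * padded_M], a1q[e])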
- M = a1q.shape[0] # non batched expert M - padded_M = a1q.shape[1] # batched expert M + M = a1q.size(0) # non batched expert M + padded_M = a1q.size(1) # batched expert M _, K, N = w2.shape device = a1q.device - assert w1.shape[2] == K + assert w1.size(2) == K assert global_num_experts != -1 assert a1q_scale is not None @@ -91,8 +92,8 @@ def run_cutlass_moe_fp8( else: local_topk_ids = topk_ids - topk = local_topk_ids.shape[1] - local_E = w1.shape[0] + topk = local_topk_ids.size(1) + local_E = w1.size(0) if use_batched_format: assert expert_num_tokens is not None @@ -111,10 +112,10 @@ def run_cutlass_moe_fp8( problem_sizes2, expert_num_tokens, local_E, padded_M, N, K) - w1_scale = w1_scale.reshape(w1_scale.shape[0], -1) - w2_scale = w2_scale.reshape(w2_scale.shape[0], -1) - a1q = a1q.reshape(-1, a1q.shape[2]) - a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous() + w1_scale = w1_scale.reshape(w1_scale.size(0), -1) + w2_scale = w2_scale.reshape(w2_scale.size(0), -1) + a1q = a1q.reshape(-1, a1q.size(2)) + a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() else: expert_offsets = torch.empty((global_num_experts + 1), @@ -151,19 +152,19 @@ def run_cutlass_moe_fp8( a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.shape[0], ), + ab_strides1 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) - c_strides1 = torch.full((w1.shape[0], ), + c_strides1 = torch.full((w1.size(0), ), 2 * N, device=device, dtype=torch.int64) - ab_strides2 = torch.full((w1.shape[0], ), + ab_strides2 = torch.full((w1.size(0), ), N, device=device, dtype=torch.int64) - c_strides2 = torch.full((w1.shape[0], ), + c_strides2 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) @@ -202,26 +203,43 @@ def run_cutlass_moe_fp8( # TODO (bnell): split class batched vs. non-batched? +# maybe remove need for passing aq to workspace_shapes class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_experts_per_worker: int, - out_dtype: torch.dtype, - per_act_token: bool, - per_out_ch: bool, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, use_batched_format: bool = False, ): - super().__init__() + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=True, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + )) + assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype - self.per_act_token = per_act_token - self.per_out_ch = per_out_ch self.use_batched_format = use_batched_format + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + def supports_chunking(self) -> bool: return not self.use_batched_format + def supports_expert_map(self) -> bool: + return not self.use_batched_format + def workspace_shapes( self, a: torch.Tensor, @@ -237,7 +255,7 @@ def workspace_shapes( workspace2: tuple[int, ...] = () output: tuple[int, ...] 
= () if self.use_batched_format: - padded_M = aq.shape[1] + padded_M = aq.size(1) workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) output = (self.max_experts_per_worker, padded_M, K) @@ -245,7 +263,8 @@ def workspace_shapes( workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) output = (M * topk, K) - return (workspace1, workspace2, output, self.out_dtype) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) def apply( self, @@ -270,13 +289,14 @@ def apply( assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" activation_callable = lambda i, o: self.activation(activation, i, o) - run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids, - activation_callable, global_num_experts, - expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, - expert_num_tokens, self.out_dtype, - self.per_act_token, self.per_out_ch, - self.use_batched_format) + in_dtype = hidden_states.dtype + run_cutlass_moe_fp8( + output, hidden_states, w1, w2, topk_ids, activation_callable, + global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, self.per_out_ch_quant, + self.use_batched_format) def cutlass_moe_fp8( @@ -332,20 +352,18 @@ def cutlass_moe_fp8( """ per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.shape[0] + per_out_ch = w1_scale.numel() != w1_q.size(0) - out_dtype = a.dtype + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( + 0) fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - quant_dtype=torch.float8_e4m3fn, - per_channel_quant=per_act_token, - ), + MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - max_experts_per_worker=global_num_experts, - out_dtype=out_dtype, - per_act_token=per_act_token, - per_out_ch=per_out_ch, + max_experts_per_worker=num_experts, + out_dtype=a.dtype, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, use_batched_format=False, ), ) @@ -358,7 +376,7 @@ def cutlass_moe_fp8( topk_ids, False, activation, - global_num_experts if global_num_experts != -1 else w1_q.size(0), + num_experts, expert_map, w1_scale, w2_scale, @@ -425,11 +443,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.shape[0] == m and topk_ids.shape[0] + assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") out_dtype = a.dtype - num_topk = topk_ids.shape[1] + num_topk = topk_ids.size(1) expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) @@ -463,7 +481,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, out_dtype, device) del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. 
- intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2), + intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), device=device, dtype=out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b4473b90738..09dcd174e15 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -8,6 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( @@ -48,7 +49,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, M = hidden_states.size(0) _, K, N = w2.size() if not _valid_deep_gemm_shape(M, N, K): - logger.debug("DeepGemm disabled: unalinged problem size.") + logger.debug("DeepGemm disabled: unaligned problem size.") return False if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): @@ -67,16 +68,31 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): - super().__init__() - self.block_shape = deep_gemm_block_shape() + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=deep_gemm_block_shape(), + )) + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: return True + def supports_expert_map(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + assert self.block_shape is not None # We use global_num_experts due to how moe_align_block_size handles # expert_maps. num_experts = global_num_experts @@ -109,6 +125,7 @@ def apply( expert_num_tokens: Optional[torch.Tensor], ): import deep_gemm as dg + assert self.block_shape is not None a1q = hidden_states _, N, K = w1.size() @@ -215,8 +232,7 @@ def deep_gemm_moe_fp8( - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn, - block_shape=deep_gemm_block_shape()), + MoEPrepareAndFinalizeNoEP(), DeepGemmExperts(), ) return fn( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 8c21d8aa53a..da8921368d6 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -6,6 +6,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -15,22 +16,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. 
""" - def __init__(self, - buffer: deep_ep.Buffer, - world_size: int, - rank: int, - dp_size: int, - rank_expert_offset: int, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): + def __init__(self, buffer: deep_ep.Buffer, world_size: int, rank: int, + dp_size: int, rank_expert_offset: int): super().__init__() self.buffer = buffer self.world_size = world_size self.rank = rank self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset - self.quant_dtype = quant_dtype - self.block_shape = block_shape # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. @@ -39,6 +32,10 @@ def __init__(self, # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + def max_num_tokens_per_rank(self) -> Optional[int]: return None @@ -55,13 +52,6 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_quant(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], per_act_token: bool): - tokens, token_scales = moe_kernel_quantize_input( - tokens, token_scales, self.quant_dtype, per_act_token, - self.block_shape) - return tokens, token_scales - def _do_dispatch(self, tokens: torch.Tensor, token_scales: Optional[torch.Tensor], rank_topk_ids: torch.Tensor, @@ -130,43 +120,52 @@ def prepare( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - rank_topk_weights: torch.Tensor, - rank_topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: if apply_router_weight_on_input: - topk = rank_topk_ids.size(1) + topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * rank_topk_weights.to(a1.dtype) + a1 = a1 * topk_weights.to(a1.dtype) # Check if there is a block_shape / or if we can infer the quantization # schemes from the scales. per_token_quant = None - if all([x is None for x in [self.block_shape, a1_scale, a2_scale] - ]) and self.quant_dtype is not None: + if all([ + x is None + for x in [quant_config.block_shape, a1_scale, a2_scale] + ]) and quant_config.quant_dtype is not None: # Quantization required despite none of the inputs suggesting # quantization. Fallback to per_token_dynamic quant. 
per_token_quant = True else: - per_token_quant = ((self.block_shape is not None) or + per_token_quant = ((quant_config.block_shape is not None) or (a1_scale is not None and a1_scale.numel() != 1) or (a2_scale is not None and a2_scale.numel() != 1)) if per_token_quant: - a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True) + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape, + ) (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1q, token_scales=a1q_scale, - rank_topk_ids=rank_topk_ids, - rank_topk_weights=rank_topk_weights, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, num_experts=num_experts) else: # DeepEP kernels only support dispatching per-token-quant @@ -175,15 +174,18 @@ def prepare( expert_topk_weights) = self._do_dispatch( tokens=a1, token_scales=None, - rank_topk_ids=rank_topk_ids, - rank_topk_weights=rank_topk_weights, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, num_experts=num_experts) # quantize now expert_x_scale = None if expert_x.numel() != 0: - expert_x, expert_x_scale = self._do_quant(expert_x, - a1_scale, - per_act_token=False) + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, expert_topk_weights) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 3484a7a8a49..c0057061208 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -5,6 +5,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -25,7 +26,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, expert_x_fp32 = expert_x_fp8.to(torch.float32).view( num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) - return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -39,26 +40,26 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__(self, buffer: deep_ep.Buffer, + max_tokens_per_rank: int, world_size: int, dp_size: int, - max_tokens_per_rank: int, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None, use_fp8_dispatch: bool = False): super().__init__() self.buffer = buffer + self.max_tokens_per_rank = max_tokens_per_rank self.world_size = world_size self.dp_size = dp_size - self.quant_dtype = quant_dtype - self.block_shape = block_shape - self.max_tokens_per_rank = max_tokens_per_rank self.use_fp8_dispatch = use_fp8_dispatch # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. 
self.handle = None + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_tokens_per_rank @@ -66,12 +67,17 @@ def topk_indices_dtype(self) -> Optional[torch.dtype]: return torch.int64 def _do_quant( - self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - a1_dtype: torch.dtype + self, + x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - block_k = self.block_shape[1] if self.block_shape is not None else None + block_k = block_shape[1] if block_shape is not None else None if self.use_fp8_dispatch: if block_k == DEEPEP_QUANT_BLOCK_SIZE: # DeepEP kernels did the quantization for us. @@ -84,30 +90,42 @@ def _do_quant( assert isinstance(x, torch.Tensor) + # TODO (bnell): # Check if there is a block_shape / or if we can infer the quantization # schemes from the scales. - per_token_quant = None - if all([v is None for v in [self.block_shape, a1_scale, a2_scale] - ]) and self.quant_dtype is not None: + _per_act_token_quant = False + if all([v is None for v in [block_shape, a1_scale, a2_scale] + ]) and quant_dtype is not None: # Quantization required despite none of the inputs suggesting # quantization. Fallback to per_token_dynamic quant. - per_token_quant = True + #print(f"DYNAMIC") + _per_act_token_quant = True else: - per_token_quant = ((self.block_shape is not None) or - (a1_scale is not None and a1_scale.numel() != 1) - or (a2_scale is not None - and a2_scale.numel() != 1)) + _per_act_token_quant = ( + (block_shape is not None) + or (a1_scale is not None and a1_scale.numel() != 1) + or (a2_scale is not None and a2_scale.numel() != 1)) + #print(f"{block_shape} {a1_scale} {a2_scale}") + + # assert per_act_token_quant == ( + # (block_shape is not None) + # or (a1_scale is not None and a1_scale.numel() != 1) + # or (a2_scale is not None and a2_scale.numel() != 1)) + + # TODO(bnell) + assert per_act_token_quant == _per_act_token_quant, \ + f"{per_act_token_quant} == {_per_act_token_quant}" num_experts, max_tokens, hidden_dim = x.size() # TODO (varun): Optimization - Use a batched version of quant x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype, - per_token_quant, - self.block_shape) + x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, + _per_act_token_quant, + block_shape) x = x.view((num_experts, -1, hidden_dim)) - if per_token_quant: + if _per_act_token_quant: assert x_scales is not None x_scales = x_scales.view(num_experts, max_tokens, -1) @@ -118,11 +136,12 @@ def prepare( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - rank_topk_weights: torch.Tensor, - rank_topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -142,24 +161,25 @@ def prepare( "low_latency kernels doesn't support dispatching per-token scales") if apply_router_weight_on_input: 
- topk = rank_topk_ids.size(1) + topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * rank_topk_weights.to(a1.dtype) + a1 = a1 * topk_weights.to(a1.dtype) # Dispatch expert_x, expert_num_tokens, self.handle, event, hook = \ self.buffer.low_latency_dispatch(a1, - rank_topk_ids, + topk_ids, self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, return_recv_hook=False) - expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale, - a1.dtype) + expert_x, expert_x_scale = self._do_quant( + expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) return (expert_x, expert_x_scale, expert_num_tokens, None, None) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a12cfafd42a..44c62047df0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -8,6 +8,7 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( @@ -317,8 +318,8 @@ def invoke_moe_batched_triton_kernel( expert_num_tokens: torch.Tensor, # [E] compute_type: tl.dtype, # Quantization data - A_scale: torch.Tensor, - B_scale: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], B_zp: torch.Tensor, # Quantization schemes use_fp8_w8a8: bool, @@ -387,14 +388,23 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): that the PPLX dispatch/combine kernels use. 
""" - def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, - rank: int): + def __init__( + self, + max_num_tokens: int, + world_size: int, + dp_size: int, + rank: int, + ): super().__init__() self.world_size = world_size self.dp_size = dp_size self.rank = rank self.max_num_tokens = max_num_tokens + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens @@ -411,6 +421,7 @@ def prepare( num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: assert a1.dim() == 2 @@ -435,22 +446,33 @@ def prepare( num_local_experts = num_experts // self.world_size + if quant_config.quant_dtype is None: + b_type = a1.dtype + else: + b_type = quant_config.quant_dtype + b_a1 = torch.zeros( (num_local_experts, self.max_num_tokens, hidden_dim), - dtype=a1.dtype, + dtype=b_type, device=a1.device) + b_a1_scale = None + + assert quant_config.quant_dtype is None, "quantization NYI" + first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() rows = torch.count_nonzero(topks.flatten()) - b_a1[expert_id - - first_expert, :rows, :] = a1[:topks.numel()][topks] - tokens_per_expert[expert_id - first_expert] = rows + idx = expert_id - first_expert + b_a1[idx, :rows, :] = a1[:topks.numel()][topks] + tokens_per_expert[idx] = rows + + assert b_a1_scale is None or b_a1_scale.ndim == 3 - return b_a1, a1_scale, tokens_per_expert, None, None + return b_a1, b_a1_scale, tokens_per_expert, None, None def finalize( self, @@ -480,7 +502,8 @@ def finalize( output[topks] = output[topks] + rhs -class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): +# XXXX BatchedNaiveExperts +class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): """ A reference MoE expert class that operates on expert batched format, i.e. E x max_num_tokens x K. 
This is the format that the pplx @@ -492,27 +515,30 @@ def __init__( max_num_tokens: int, world_size: int, dp_size: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[list[int]] = None, - block_m: Optional[int] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, ): - super().__init__() - assert block_shape is None - assert block_m is None - assert not use_fp8_w8a8, "NYI" - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" + super().__init__(quant_config) + assert quant_config is None or (not quant_config.use_fp8_w8a8 and + not quant_config.use_int8_w8a8 and + not quant_config.use_int8_w8a16 and + not quant_config.use_int4_w4a16), "NYI" self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + def supports_chunking(self) -> bool: return False + def supports_expert_map(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -590,34 +616,32 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: Optional[int] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - block_shape: Optional[list[int]] = None, - world_size: int = 1, - dp_size: int = 1, + max_num_tokens: int, + world_size: int, + dp_size: int, + quant_config: Optional[FusedMoEQuantConfig] = None, ): - super().__init__() - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a16 = use_int8_w8a16 - self.block_shape = block_shape - self.per_channel_quant = per_channel_quant + super().__init__(quant_config) + assert quant_config is None or (not quant_config.use_int8_w8a8 and + not quant_config.use_int8_w8a16 and + not quant_config.use_int4_w4a16), "NYI" self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - assert not use_int8_w8a8, "NYI" - assert not use_int4_w4a16, "NYI" - assert self.block_shape is None, "NYI" + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) def supports_chunking(self) -> bool: return False + def supports_expert_map(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -632,8 +656,7 @@ def workspace_shapes( assert a.dim() == 2 num_dp = self.world_size // self.dp_size num_experts = local_num_experts - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens + max_num_tokens = self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) output = (num_experts, max_num_tokens * num_dp, K) @@ -660,7 +683,7 @@ def apply( expert_num_tokens: Optional[torch.Tensor], ): # Check constraints.
- if self.use_int4_w4a16: + if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: @@ -682,16 +705,11 @@ def apply( assert w1.size(0) == E assert w2.size(0) == E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, - config_dtype, + quant_config.config_dtype_str, max_num_tokens, block_shape=self.block_shape, ) @@ -716,7 +734,7 @@ def apply( intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - if self.use_fp8_w8a8: + if self.quant_config.use_fp8_w8a8: intermediate_cache1.fill_(0) # MM1 @@ -728,12 +746,14 @@ def apply( A_scale=a1q_scale, B_scale=w1_scale, B_zp=w1_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, block_shape=self.block_shape) + intermediate_cache2.fill_(0) + # TODO: would be nice to use expert_num_tokens here to reduce # garbage compute self.activation(activation, intermediate_cache2.view(-1, N // 2), @@ -745,8 +765,8 @@ def apply( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, - per_channel_quant=self.per_channel_quant, + quant_dtype=torch.float8_e4m3fn if self.quant_config.use_fp8_w8a8 else None, + per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) qintermediate_cache2 = qintermediate_cache2.view( @@ -760,8 +780,8 @@ def apply( A_scale=a2q_scale, B_scale=w2_scale, B_zp=w2_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, block_shape=self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 437e80696ac..69d32aa597c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -12,6 +12,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, get_config_quant_dtype) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( @@ -462,36 +464,39 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -def invoke_fused_moe_kernel(A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: dict[str, Any], - compute_type: tl.dtype, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[list[int]] = None) -> None: +def 
invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, + quant_config: FusedMoEQuantConfig, +) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + use_fp8_w8a8 = quant_config.use_fp8_w8a8 + use_int8_w8a8 = quant_config.use_int8_w8a8 + use_int8_w8a16 = quant_config.use_int8_w8a16 + use_int4_w4a16 = quant_config.use_int4_w4a16 + block_shape = quant_config.block_shape + if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) + assert (block_shape is None + or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) + assert (block_shape is None + or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None @@ -500,19 +505,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert A_scale is None assert B_scale is None - M = A.shape[0] + M = A.size(0) num_tokens = M * top_k - EM = sorted_token_ids.shape[0] - if A.shape[0] < config["BLOCK_SIZE_M"]: + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. # We assume that top_ids of each token is unique, so # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. 
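# Worked example with hypothetical sizes: for A.size(0) == 4 tokens, top_k == 2 and BLOCK_SIZE_M == 16, at most 4 * 2 = 8 padded blocks can hold valid tokens, so EM below is clamped to 4 * 2 * 16 = 128 slots instead of the full sorted_token_ids length.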
- EM = min(sorted_token_ids.shape[0], - A.shape[0] * top_k * config['BLOCK_SIZE_M']) + EM = min(sorted_token_ids.size(0), + A.size(0) * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( - B.shape[1], META['BLOCK_SIZE_N']), ) + B.size(1), META['BLOCK_SIZE_N']), ) if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: @@ -522,16 +527,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_moe_wna16_cuda = should_moe_wna16_use_cuda( num_valid_tokens=num_tokens, group_size=block_shape[1], - num_experts=B.shape[0], + num_experts=B.size(0), bit=4 if use_int4_w4a16 else 8) config = config.copy() config.update( get_moe_wna16_block_config(config=config, use_moe_wna16_cuda=use_moe_wna16_cuda, num_valid_tokens=num_tokens, - size_k=A.shape[1], - size_n=B.shape[1], - num_experts=B.shape[1], + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), group_size=block_shape[1], real_top_k=top_k, block_size_m=config["BLOCK_SIZE_M"])) @@ -556,8 +561,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - A.shape[1], + B.size(1), + A.size(1), EM, num_tokens, A.stride(0), @@ -573,7 +578,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0, - block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, group_size=block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, @@ -599,8 +604,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - B.shape[2], + B.size(1), + B.size(2), EM, num_tokens, A.stride(0), @@ -818,7 +823,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[list[int]] = None, -): +) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() if override_config: @@ -873,10 +878,10 @@ def fused_topk( renormalize: bool, indices_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") - M, _ = hidden_states.shape + M, _ = hidden_states.size() topk_weights = torch.empty(M, topk, @@ -915,7 +920,7 @@ def grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") if scoring_func == "softmax": @@ -925,7 +930,7 @@ def grouped_topk( else: raise ValueError(f"Unsupported scoring function: {scoring_func}") - num_token = scores.shape[0] + num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. 
We use biased # scores for expert selection but original scores for routing weights @@ -942,7 +947,7 @@ def grouped_topk( group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] @@ -962,38 +967,6 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -def get_config_dtype_str( - dtype: torch.dtype, - use_int4_w4a16: Optional[bool] = False, - use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False) -> Optional[str]: - if use_fp8_w8a8: - return "fp8_w8a8" - elif use_int8_w8a16: - return "int8_w8a16" - elif use_int4_w4a16: - return "int4_w4a16" - elif dtype == torch.float: - # avoiding cases where kernel fails when float32 MoE - # use fp16/bfloat16 configs - return "float32" - return None - - -# TODO (bnell): use scalar_type instead of bools? -def get_config_qtype( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, -) -> Optional[torch.dtype]: - if use_fp8_w8a8: - return torch.float8_e4m3fn - elif use_int8_w8a8: - return torch.int8 - return None - - def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1001,11 +974,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1013,14 +982,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> None: + a2_scale: Optional[torch.Tensor] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - per_channel_quant, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + activation, apply_router_weight_on_input, quant_config, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, w2_zp, a1_scale, a2_scale) def inplace_fused_experts_fake( @@ -1031,11 +997,7 @@ def inplace_fused_experts_fake( topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1043,11 +1005,11 @@ w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> None: + a2_scale: Optional[torch.Tensor] = None) -> None: pass +#
TODO: get rid of these? replace with modular op? direct_register_custom_op( op_name="inplace_fused_experts", op_func=inplace_fused_experts, @@ -1065,11 +1027,7 @@ def outplace_fused_experts( topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1077,15 +1035,12 @@ def outplace_fused_experts( w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> torch.Tensor: + a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, - use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, - use_int4_w4a16, per_channel_quant, - global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + quant_config, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale) def outplace_fused_experts_fake( @@ -1095,11 +1050,8 @@ def outplace_fused_experts_fake( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + apply_router_weight_on_input: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1107,8 +1059,7 @@ def outplace_fused_experts_fake( w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> torch.Tensor: + a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1145,11 +1096,7 @@ def fused_experts(hidden_states: torch.Tensor, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1158,11 +1105,14 @@ def fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, allow_deep_gemm: bool = False) -> torch.Tensor: + + if quant_config is None: + quant_config = FusedMoEQuantConfig() + # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. 
- N = w1.shape[1] + N = w1.size(1) if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False @@ -1191,11 +1141,7 @@ def fused_experts(hidden_states: torch.Tensor, topk_ids=topk_ids, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, + quant_config=quant_config, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=w1_scale, @@ -1203,8 +1149,7 @@ def fused_experts(hidden_states: torch.Tensor, w1_zp=w1_zp, w2_zp=w2_zp, a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + a2_scale=a2_scale) def fused_experts_impl( @@ -1216,11 +1161,7 @@ def fused_experts_impl( inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1229,17 +1170,18 @@ def fused_experts_impl( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, ) -> torch.Tensor: + assert quant_config is not None + # Check constraints. - if use_int4_w4a16: - assert hidden_states.shape[1] // 2 == w1.shape[ - 2], "Hidden size mismatch" + if quant_config.use_int4_w4a16: + assert hidden_states.size(1) // 2 == w1.size(2), ( + "Hidden size mismatch") else: - assert hidden_states.shape[1] == w1.shape[2], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" @@ -1247,32 +1189,23 @@ def fused_experts_impl( torch.float32, torch.float16, torch.bfloat16 ] - num_tokens = hidden_states.shape[0] - E, N, _ = w1.shape - K = w2.shape[1] + num_tokens = hidden_states.size(0) + E, N, _ = w1.size() + K = w2.size(1) if global_num_experts == -1: global_num_experts = E - top_k_num = topk_ids.shape[1] + top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) - config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - dtype=hidden_states.dtype) - - qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16) get_config_func = functools.partial( try_get_optimal_moe_config, - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, - config_dtype, + quant_config.config_dtype_str, block_shape=block_shape, ) @@ -1310,7 +1243,7 @@ def fused_experts_impl( min((chunk + 1) * CHUNK_SIZE, num_tokens)) 
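# Worked example with hypothetical sizes: for num_tokens == 100_000 and CHUNK_SIZE == 32_768 the loop visits the slices [0, 32768), [32768, 65536), [65536, 98304) and [98304, 100000); the min(...) above clamps the final, partial chunk.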
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape + tokens_in_chunk, _ = curr_hidden_states.size() if tokens_in_chunk == 0: break @@ -1322,7 +1255,7 @@ def fused_experts_impl( # do not need to be adjusted. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] + topk_ids.size(1)] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1332,8 +1265,8 @@ def fused_experts_impl( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, - qtype=qtype, - per_channel_quant=per_channel_quant, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=per_channel_quant, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( @@ -1354,12 +1287,7 @@ def fused_experts_impl( top_k_num, config, compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape) + quant_config) if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1373,8 +1301,8 @@ def fused_experts_impl( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - qtype=qtype, - per_channel_quant=per_channel_quant, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, @@ -1391,14 +1319,9 @@ def fused_experts_impl( 1, config, compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape) - - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + quant_config) + + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states @@ -1417,11 +1340,7 @@ def fused_moe( num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1430,7 +1349,6 @@ def fused_moe( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1452,20 +1370,12 @@ def fused_moe( - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. 
- - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. + - quant_config (Optional[FusedMoEQuantConfig]): Quantization parameters for + this MoE op. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. @@ -1475,8 +1385,6 @@ def fused_moe( a1. - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. - - block_shape: (Optional[list[int]]): Optional block size for block-wise - quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -1501,11 +1409,7 @@ def fused_moe( topk_ids, inplace=inplace, activation=activation, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, + quant_config=quant_config, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=w1_scale, @@ -1513,38 +1417,30 @@ def fused_moe( w1_zp=w1_zp, w2_zp=w2_zp, a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + a2_scale=a2_scale) class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[list[int]] = None, - block_m: Optional[int] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, ): - super().__init__() - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int8_w8a16 = use_int8_w8a16 - self.block_shape = block_shape - self.block_m = block_m - self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16) - self.per_channel_quant = per_channel_quant + super().__init__(quant_config) + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: return True + def supports_expert_map(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, @@ -1582,7 +1478,7 @@ def apply( expert_num_tokens: Optional[torch.Tensor], ): # Check constraints. 
- if self.use_int4_w4a16: + if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: @@ -1605,16 +1501,11 @@ def apply( if global_num_experts == -1: global_num_experts = E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, - config_dtype, + quant_config.config_dtype_str, num_tokens, block_shape=self.block_shape, ) @@ -1660,7 +1551,7 @@ def apply( use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, + per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape) self.activation(activation, intermediate_cache2, @@ -1669,8 +1560,8 @@ def apply( a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, - self.block_shape) + intermediate_cache2, a2_scale, self.quant_dtype, + self.per_act_token_quant, self.block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, @@ -1690,7 +1581,7 @@ def apply( use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, + per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape) @@ -1699,27 +1590,17 @@ def modular_triton_fused_moe( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, - per_channel_quant: bool, + per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> mk.FusedMoEModularKernel: - qtype = get_config_qtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - ) return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - quant_dtype=qtype, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - ), + MoEPrepareAndFinalizeNoEP(), TritonExperts( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, + per_act_token_quant=per_act_token_quant, block_shape=block_shape, ), ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f217588..bbdbaf5afbe 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,26 +3,26 @@ import importlib from abc import abstractmethod -from dataclasses import dataclass from enum import Enum -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch import torch.nn.functional as F -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.config import 
( + FusedMoEConfig, FusedMoEParallelConfig) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.base_config import ( @@ -38,18 +38,13 @@ if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts - from .modular_kernel import (FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) if has_pplx: - from .pplx_prepare_finalize import PplxPrepareAndFinalize + from .pplx_prepare_finalize import (PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes) if has_deepep: from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize -else: - fused_experts = None # type: ignore - FusedMoEPermuteExpertsUnpermute = None # type: ignore - FusedMoEPrepareAndFinalize = None # type: ignore + if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_grouped_topk as grouped_topk) @@ -59,207 +54,8 @@ from .moe_pallas import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore -logger = init_logger(__name__) - - -@dataclass -class FusedMoEParallelConfig: - tp_size: int - dp_size: int - ep_size: int - tp_rank: int - dp_rank: int - ep_rank: int - - use_ep: bool # whether to use EP or not - - @property - def use_all2all_kernels(self): - return self.dp_size > 1 and self.use_ep - - @property - def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") - - @property - def use_deepep_ht_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") - - @property - def use_deepep_ll_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - - @staticmethod - def make(tp_size_: int, dp_size_: int, - vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": - """ - Determine MoE parallel configuration. Based on the input tp_size_, - dp_size_, ep_size_ and vllm's parallel config, determine what - level's of parallelism to use in the fused moe layer. - - Args: - tp_size_ (int): tp_size passed into the FusedMoE constructor. - dp_size_ (int): dp_size passed into the FusedMoE constructor. - ep_size_ (int): ep_size passed into the FusedMoE constructor. - vllm_parallel_config (ParallelConfig): vllm's parallel config - object. - - Examples: - When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, - we simply return the sizes unaltered and the ranks set to 0. - - Expert Parallelism is considered only when either dp_size_ or tp_size_ - is non trivial. - - When TP = 2, DP = 1 and EP = False, the configuration on different - devices, - - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // - legend : {size, rank} - - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} - - Comment : Tensors are sharded across 2 devices. - - When TP = 1, DP = 2 and EP = False, the configuration on different - devices, - - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} - - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded - across 2 decvices. 
- - When TP = 2, DP = 2 and EP = False, the configuration on different - devices, - - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} - - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} - - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} - - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded - across 4 devices. - - When, TP = 2, DP = 1 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} - - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} - - Comment: The experts are split between the 2 devices. - - When, TP = 1, DP = 2 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} - - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} - - Comment: There are 2 engine instances and the experts are split - between the 2 devices. - - When TP = 2, DP = 2 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} - - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} - - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} - - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} - - Comment: There are 2 engine instances and the experts are split - between the 4 devices. - """ - - def flatten_tp_across_dp(dp_rank: int): - tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size_ * tp_size_ devices. Update tp_size - # and tp_rank so we shard across all devices. - tp_size = dp_size_ * tp_size_ - tp_rank = dp_rank * tp_size_ + tp_rank - return tp_size, tp_rank - - use_ep = (dp_size_ * tp_size_ > 1 - and vllm_parallel_config.enable_expert_parallel) - - dp_size = dp_size_ - dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = flatten_tp_across_dp(dp_rank) - - if not use_ep: - return FusedMoEParallelConfig(tp_size=tp_size, - tp_rank=tp_rank, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=1, - ep_rank=0, - use_ep=False) - # DP + EP / TP + EP / DP + TP + EP - assert use_ep - # In EP, each device owns a set of experts fully. There is no tensor - # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. - ep_size = tp_size - ep_rank = tp_rank - return FusedMoEParallelConfig(tp_size=1, - tp_rank=0, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - use_ep=True) - - -# Adapted from pplx-kernels tests/all_to_all_utils.py -@dataclass -class MoEConfig: - num_experts: int - experts_per_token: int - hidden_dim: int - - num_local_experts: int - moe_parallel_config: FusedMoEParallelConfig - - in_dtype: torch.dtype # The activation type. - quant_dtype: torch.dtype = None - - # TODO: add more quantization params, blocked, per-token, etc. 
-    block_size: int = 128
-
-    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
-
-    def __post_init__(self):
-        if self.dp_size > 1:
-            logger.debug("Using MOEConfig::max_num_tokens=%d",
-                         self.max_num_tokens)
-
-    @property
-    def tp_size(self):
-        return self.moe_parallel_config.tp_size
-
-    @property
-    def dp_size(self):
-        return self.moe_parallel_config.dp_size
-
-    @property
-    def ep_size(self):
-        return self.moe_parallel_config.ep_size
-
-    @property
-    def tp_rank(self):
-        return self.moe_parallel_config.tp_rank
-
-    @property
-    def dp_rank(self):
-        return self.moe_parallel_config.dp_rank
-
-    @property
-    def ep_rank(self):
-        return self.moe_parallel_config.ep_rank
-
-    @property
-    def use_ep(self):
-        return self.moe_parallel_config.use_ep
-
-    @property
-    def use_pplx_kernels(self):
-        return self.moe_parallel_config.use_pplx_kernels
-
-    @property
-    def use_deepep_ht_kernels(self):
-        return self.moe_parallel_config.use_deepep_ht_kernels
-    @property
-    def use_deepep_ll_kernels(self):
-        return self.moe_parallel_config.use_deepep_ll_kernels
+logger = init_logger(__name__)


 class FusedMoeWeightScaleSupported(Enum):
@@ -269,21 +65,9 @@ class FusedMoeWeightScaleSupported(Enum):
     BLOCK = "block"


-def get_quant_config_input_activations(
-        quant_config: Optional[QuantizationConfig]
-) -> Optional[QuantizationArgs]:
-    if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
-            and "Linear" in quant_config.target_scheme_map and
-            "input_activations" in quant_config.target_scheme_map["Linear"]):
-        return quant_config.target_scheme_map["Linear"].get(
-            "input_activations")
-    else:
-        return None
-
-
 class FusedMoEMethodBase(QuantizeMethodBase):

-    moe: MoEConfig
+    moe: FusedMoEConfig

     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -291,23 +75,25 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError

-    def init_prepare_finalize(self, moe: MoEConfig,
+    def init_prepare_finalize(self, moe: FusedMoEConfig,
                               quant_config: Optional[QuantizationConfig]):
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

         self.moe = moe
-        quant_dtype = None
-        act_quant_block_size = None
-        from vllm.model_executor.layers.quantization.fp8 import Fp8Config
-        if isinstance(quant_config, Fp8Config):
-            act_quant_block_size = quant_config.weight_block_size
-            quant_dtype = torch.float8_e4m3fn
-
-        prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
-                                         DeepEPHTPrepareAndFinalize,
-                                         DeepEPLLPrepareAndFinalize]] = None
+
+        prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
+
         if moe.use_pplx_kernels:
+            hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
+                moe.max_num_tokens,
+                moe.hidden_dim,
+                moe.in_dtype,
+                moe.quant_dtype,
+                per_act_token_quant=moe.per_act_token_quant,
+                block_shape=moe.block_shape,
+            )
+
             all_to_all_args = dict(
                 max_num_tokens=moe.max_num_tokens,
                 num_experts=moe.num_experts,
@@ -317,14 +103,8 @@ def init_prepare_finalize(self, moe: MoEConfig,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
-                # For blocked per token: set to
-                #   ceil_div(hidden_dim, block_size) * sizeof(float32)
-                # For per-token: set to sizeof(float32)
-                hidden_dim_scale_bytes=(
-                    0 if moe.quant_dtype.itemsize != 1 else
-                    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
-                     torch.float32.itemsize)),
+                hidden_dim_bytes=hidden_dim_bytes,
+                hidden_dim_scale_bytes=hidden_scale_bytes,
             )

             # Intranode pplx a2a takes a group name while internode does not.
@@ -334,20 +114,12 @@ def init_prepare_finalize(self, moe: MoEConfig,

             handle = all2all_manager.get_handle(all_to_all_args)

-            input_activations = get_quant_config_input_activations(
-                quant_config)
-
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
                 world_size=all2all_manager.world_size,
                 rank=all2all_manager.rank,
-                # dp_size actually means tp_size, bug in pplx kernels
-                dp_size=all2all_manager.tp_group.world_size,
-                quant_dtype=moe.quant_dtype,
-                per_act_token=(input_activations.strategy
-                               == QuantizationStrategy.TOKEN
-                               if input_activations is not None else False),
+                dp_size=moe.dp_size,
             )
         elif moe.use_deepep_ht_kernels:
             assert moe.dp_size == all2all_manager.dp_world_size
@@ -361,8 +133,6 @@ def init_prepare_finalize(self, moe: MoEConfig,
                 dp_size=all2all_manager.dp_world_size,
                 rank_expert_offset=all2all_manager.rank *
                 moe.num_local_experts,
-                quant_dtype=quant_dtype,
-                block_shape=act_quant_block_size,
             )

         elif moe.use_deepep_ll_kernels:
@@ -381,16 +151,14 @@ def init_prepare_finalize(self, moe: MoEConfig,
             # profiling. Turning it off for now.
             prepare_finalize = DeepEPLLPrepareAndFinalize(
                 handle,
+                max_tokens_per_rank=moe.max_num_tokens,
                 world_size=all2all_manager.world_size,
                 dp_size=all2all_manager.dp_world_size,
-                max_tokens_per_rank=moe.max_num_tokens,
-                quant_dtype=quant_dtype,
-                block_shape=act_quant_block_size,
-                use_fp8_dispatch=False,
             )

         self.topk_indices_dtype = None
         if prepare_finalize is not None:
+            logger.debug("%s", prepare_finalize.__class__.__name__)
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
             experts = self.select_gemm_impl(prepare_finalize, moe)
             self.fused_experts = FusedMoEModularKernel(
@@ -399,13 +167,15 @@ def init_prepare_finalize(self, moe: MoEConfig,
             )

     def select_gemm_impl(
-            self, prepare_finalize: FusedMoEPrepareAndFinalize,
-            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
         raise NotImplementedError(
-            "Subclass must select appropriate gemm implementation"
-            " based on the prepare_finalize")
+            f"{self.__class__.__name__} must select appropriate gemm "
+            "implementation based on the prepare_finalize")

     @abstractmethod
     def apply(
@@ -433,7 +203,7 @@ def apply(
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

-    def __init__(self, moe: MoEConfig):
+    def __init__(self, moe: FusedMoEConfig):
         super().__init__()
         self.fused_experts = fused_experts  # type: ignore
         self.topk_indices_dtype = None
@@ -446,44 +216,28 @@ def __init__(self, moe: MoEConfig):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore

-    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
-                         moe: Optional[MoEConfig]):
+    def select_gemm_impl(
+            self, prepare_finalize: FusedMoEPrepareAndFinalize,
+            moe: FusedMoEConfig) -> FusedMoEPermuteExpertsUnpermute:
+
         assert self.fused_experts == fused_experts

         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
-        use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
-        ) is not None
-        if use_batched_experts:
+        if (prepare_finalize.activation_format ==
+                FusedMoEActivationFormat.BatchedExperts):
             logger.debug("BatchedTritonExperts %s", self.moe)
             assert self.moe.dp_size == all2all_manager.dp_world_size
-            experts = BatchedTritonExperts(
+            return BatchedTritonExperts(
                 max_num_tokens=self.moe.max_num_tokens,
                 world_size=all2all_manager.world_size,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
-            )
-        return experts
+            return TritonExperts()

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -840,6 +594,7 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
+        all2all_manager = get_ep_group().device_communicator.all2all_manager

         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
@@ -848,18 +603,18 @@ def __init__(
                 get_tensor_model_parallel_world_size()),
             dp_size_=(dp_size if dp_size is not None else
                       get_dp_group().world_size),
+            world_size_=(all2all_manager.world_size
+                         if all2all_manager is not None else 1),
             vllm_parallel_config=vllm_config.parallel_config))

         self.global_num_experts = num_experts

         # For smuggling this layer into the fused moe custom op
-        self.use_direct_call = self.dp_size == 1
-        if not self.use_direct_call:
-            compilation_config = vllm_config.compilation_config
-            if prefix in compilation_config.static_forward_context:
-                raise ValueError("Duplicate layer name: {}".format(prefix))
-            compilation_config.static_forward_context[prefix] = self
-            self.layer_name = prefix
+        compilation_config = vllm_config.compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError("Duplicate layer name: {}".format(prefix))
+        compilation_config.static_forward_context[prefix] = self
+        self.layer_name = prefix

         # Determine expert maps
         if self.use_ep:
@@ -896,25 +651,22 @@ def __init__(
                 from vllm_hpu_extension.ops import DynamicFusedMOE
                 self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

-        # Only support float8 for now.
-        quant_dtype = params_dtype
-        if quant_config is not None:
-            input_activations = get_quant_config_input_activations(
-                quant_config)
-            if (input_activations is not None
-                    and input_activations.num_bits == 8
-                    and input_activations.type == QuantizationType.FLOAT):
-                quant_dtype = torch.float8_e4m3fn
-
-        moe = MoEConfig(
+        if vllm_config.model_config is not None:
+            model_dtype = vllm_config.model_config.dtype
+        else:
+            # TODO (bnell): This is a hack to get test_mixtral_moe to work
+            # since model_config is not set in the pytest test.
+            model_dtype = params_dtype
+
+        moe = FusedMoEConfig.make(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            in_dtype=params_dtype,
-            quant_dtype=quant_dtype,
+            in_dtype=model_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+            quant_config=quant_config,
         )
         self.moe_config = moe
         self.quant_config = quant_config
@@ -951,15 +703,14 @@ def __init__(
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels):
-            act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
-                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
-                dtype=act_dtype,
+                (moe.max_num_tokens, self.hidden_size),
+                dtype=moe.in_dtype,
                 device=torch.cuda.current_device())

             self.batched_router_logits = torch.zeros(
-                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
-                dtype=act_dtype,
+                (moe.max_num_tokens, self.global_num_experts),
+                dtype=moe.in_dtype,
                 device=torch.cuda.current_device())

     @property
@@ -1353,11 +1104,8 @@ def maybe_all_reduce_tensor_model_parallel(

     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        if self.use_direct_call:
-            return self.forward_impl(hidden_states, router_logits)
-        else:
-            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                              self.layer_name)
+        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                          self.layer_name)

     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
@@ -1380,7 +1128,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             assert (self.batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
-            assert (self.batched_router_logits.size(0) # type: ignore
+            assert (self.batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)

             staged_hidden_states = self.batched_hidden_states[:
                                                               chunk_size, :]  # type: ignore
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index ed3b6b8a1af..1557ce3fb82 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from enum import Enum
 from math import prod
-from typing import Optional
+from typing import Optional, final

 import torch

 import vllm.envs as envs
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.utils import cdiv

@@ -82,6 +84,18 @@ def _moe_problem_size(
     return E, M, N, K, topk


+class FusedMoEActivationFormat(Enum):
+    """
+    The standard activation format (num_tokens, hidden dim).
+    """
+    Standard = "standard",
+    """
+    The batched experts format (num experts, max tokens per expert, hidden dim)
+    """
+    BatchedExperts = "batched_experts",
+
+
+# TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
     An abstract base class for the [Quantize-Prepare] and [Finalize] steps
@@ -99,6 +113,7 @@ def prepare(
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
@@ -148,6 +163,15 @@ def finalize(
         """
         raise NotImplementedError

+    @property
+    @abstractmethod
+    def activation_format(self) -> FusedMoEActivationFormat:
+        """
+        A property indicating the output format of the activations for the
+        'prepare' method.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
         """
@@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     above.
     """

+    def __init__(
+        self,
+        quant_config: Optional[FusedMoEQuantConfig],
+    ):
+        if quant_config is not None:
+            self.quant_config = quant_config
+        else:
+            self.quant_config = FusedMoEQuantConfig()
+
+    @property
+    @abstractmethod
+    def activation_formats(
+            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
+        """
+        A property which is a tuple of the input and output activation formats
+        for the 'apply' method.
+        """
+        raise NotImplementedError
+
+    @property
+    def quant_dtype(self) -> Optional[torch.dtype]:
+        return self.quant_config.quant_dtype
+
+    @property
+    def block_shape(self) -> Optional[list[int]]:
+        return self.quant_config.block_shape
+
+    @property
+    def per_act_token_quant(self) -> bool:
+        return self.quant_config.per_act_token_quant
+
+    @property
+    def per_out_ch_quant(self) -> bool:
+        return self.quant_config.per_out_ch_quant
+
     # TODO (bnell): make this return a CHUNK_SIZE or None instead?
     @abstractmethod
     def supports_chunking(self) -> bool:
@@ -185,6 +244,13 @@ def supports_chunking(self) -> bool:
         """
         raise NotImplementedError

+    @abstractmethod
+    def supports_expert_map(self) -> bool:
+        """
+        A flag indicating whether or not this class supports expert maps
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def workspace_shapes(
         self,
@@ -293,6 +359,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
     return None


+@final
 class FusedMoEModularKernel(torch.nn.Module):
     """
     This class combines a FusedMoEPrepareAndFinalize instance and
@@ -314,6 +381,8 @@ def __init__(
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
+        assert prepare_finalize.activation_format == \
+            fused_experts.activation_formats[0]

     def forward(
         self,
@@ -379,8 +448,16 @@ def forward(

         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
-             a1, a1_scale, a2_scale, topk_weights, topk_ids,
-             global_num_experts, expert_map, apply_router_weight_on_input)
+             a1,
+             a1_scale,
+             a2_scale,
+             topk_weights,
+             topk_ids,
+             global_num_experts,
+             expert_map,
+             apply_router_weight_on_input,
+             self.fused_experts.quant_config,
+         )

         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
         topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 5bc01dbf202..099ac1867b1 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -6,33 +6,74 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
+from vllm.utils import cdiv, round_up
+
+
+def pplx_hidden_dim_scale_bytes(
+    max_num_tokens: int,
+    hidden_dim: int,
+    in_dtype: torch.dtype,
+    quant_dtype: Optional[torch.dtype],
+    per_act_token_quant: bool,
+    block_shape: Optional[list[int]],
+):
+    # For blocked per token: set to
+    #   ceil_div(hidden_dim, block_size) * sizeof(float32)
+    # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
+    if quant_dtype is not None:
+        assert quant_dtype.itemsize == 1
+        hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
+        elem_size = torch.float32.itemsize
+        align = 16
+
+        if per_act_token_quant:
+            # per-token
+            assert block_shape is None
+            hidden_scale_bytes = round_up(max_num_tokens * elem_size, align)
+        elif block_shape is not None:
+            # per-group
+            block_size = block_shape[1]
+            num_blocks = cdiv(hidden_dim, block_size)
+            hidden_scale_bytes = round_up(num_blocks * elem_size, align)
+        else:
+            # per-tensor
+            # ?
+            hidden_scale_bytes = round_up(elem_size, align)
+    else:
+        hidden_dim_bytes = hidden_dim * in_dtype.itemsize
+        hidden_scale_bytes = 0
+
+    #print(f"pplx bytes {hidden_dim_bytes}, {hidden_scale_bytes}")
+
+    return hidden_dim_bytes, hidden_scale_bytes


 # The max_num_tokens, world_size and dp_size must be the same
 # as the ones used to create the AllToAll.
 class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

-    def __init__(self,
-                 a2a: pplx.AllToAll,
-                 max_num_tokens: int,
-                 world_size: int,
-                 rank: int,
-                 dp_size: int,
-                 quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None,
-                 per_act_token: bool = False):
+    def __init__(
+        self,
+        a2a: pplx.AllToAll,
+        max_num_tokens: int,
+        world_size: int,
+        rank: int,
+        dp_size: int,
+    ):
         super().__init__()
         assert max_num_tokens > 0
         self.a2a = a2a
-        self.block_shape = block_shape
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
         self.rank = rank
         self.dp_size = dp_size
-        self.quant_dtype = quant_dtype
-        self.per_act_token = per_act_token
+
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.BatchedExperts

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
@@ -45,34 +86,36 @@ def prepare(
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
         a2_scale: Optional[torch.Tensor],
-        rank_topk_weights: torch.Tensor,
-        rank_topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:
         num_tokens = a1.size(0)  # M
         hidden_dim = a1.size(-1)  # K

-        assert rank_topk_ids.size(0) == num_tokens
+        assert topk_ids.size(0) == num_tokens
         # assert expert_map is None, "NYI"

         # Is this always going to be a1.device?
         device = a1.device

         if apply_router_weight_on_input:
-            topk = rank_topk_ids.size(1)
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, (
                 "apply_router_weight_on_input is only implemented for topk=1")
-            a1 = a1 * rank_topk_weights.to(a1.dtype)
+            a1 = a1 * topk_weights.to(a1.dtype)

         repeat_cols = 4
-        repeat_rows = 1 if self.per_act_token else a1.shape[0]
+        repeat_rows = 1 if quant_config.per_act_token_quant else a1.shape[0]
         a1q, a1q_scale = moe_kernel_quantize_input(
-            a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
-            self.per_act_token, self.block_shape)
+            a1, (None if quant_config.per_act_token_quant else a1_scale),
+            quant_config.quant_dtype, quant_config.per_act_token_quant,
+            quant_config.block_shape)

         if a1q_scale is not None:
             a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
@@ -99,8 +142,8 @@ def prepare(
         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
             float32_size = torch.float32.itemsize
-            block_size = (self.block_shape[0] if self.block_shape is not None
-                          else 1) * float32_size
+            block_size = (quant_config.block_shape[1] if quant_config.
+                          block_shape is not None else 1) * float32_size
             expert_x_scale = torch.empty(
                 (
                     num_local_experts,
@@ -121,7 +164,7 @@ def prepare(
             out_expert_x_scale=expert_x_scale,
             dp_x=a1q,
             dp_x_scale=a1q_scale,
-            indices=rank_topk_ids,
+            indices=topk_ids,
             bound_m=bound_m,
         )

         if expert_x_scale is not None:
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
index 9ed95e1de9f..9e4be82f6c1 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_unpermute_and_reduce)
 from vllm.model_executor.layers.fused_moe.utils import (
@@ -13,16 +14,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):

-    def __init__(
-        self,
-        quant_dtype: Optional[torch.dtype] = None,
-        per_channel_quant: bool = False,
-        block_shape: Optional[list[int]] = None,
-    ):
-        super().__init__()
-        self.per_channel_quant = per_channel_quant
-        self.block_shape = block_shape
-        self.quant_dtype = quant_dtype
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.Standard

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return None
@@ -39,7 +33,8 @@ def prepare(
         topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool = False,
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:

@@ -50,10 +45,9 @@ def prepare(
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))

-        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
-                                                   self.quant_dtype,
-                                                   self.per_channel_quant,
-                                                   self.block_shape)
+        a1q, a1q_scale = moe_kernel_quantize_input(
+            a1, a1_scale, quant_config.quant_dtype,
+            quant_config.per_act_token_quant, quant_config.block_shape)

         return a1q, a1q_scale, None, None, None
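
A minimal usage sketch of the prepare() contract above (illustrative only, not part of the patch): quantization parameters now travel in the FusedMoEQuantConfig that FusedMoEModularKernel passes into prepare(), instead of being stored on the prepare/finalize object. The sketch assumes the post-refactor module paths shown in this diff and that a default-constructed FusedMoEQuantConfig means "no quantization"; the tensor shapes are hypothetical.

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

# Hypothetical shapes: 16 tokens, hidden_dim = 128, topk = 1, 8 experts.
a1 = torch.randn(16, 128, dtype=torch.bfloat16)
topk_weights = torch.ones(16, 1)
topk_ids = torch.zeros(16, 1, dtype=torch.int64)

# With an unquantized config, prepare() hands the activations back unchanged
# and returns no scales. An fp8 path would instead pass a config built via
# FusedMoEQuantConfig.make(use_fp8_w8a8=True, ...).
a1q, a1q_scale, _, _, _ = MoEPrepareAndFinalizeNoEP().prepare(
    a1,
    None,  # a1_scale
    None,  # a2_scale
    topk_weights,
    topk_ids,
    num_experts=8,
    expert_map=None,
    apply_router_weight_on_input=False,
    quant_config=FusedMoEQuantConfig(),
)
assert a1q.dtype == a1.dtype and a1q_scale is None
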
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index d44989cce72..05164975bba 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -308,13 +308,20 @@ def rocm_aiter_fused_experts(
         topk_ids: torch.Tensor,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
-        use_fp8_w8a8: bool = False,
-        per_channel_quant: bool = False,
+        quant_config: Optional[FusedMoEQuantConfig] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    if quant_config is None:
+        use_fp8_w8a8 = False
+        block_shape = None
+        per_channel_quant = False
+    else:
+        use_fp8_w8a8 = quant_config.use_fp8_w8a8
+        block_shape = quant_config.block_shape
+        per_channel_quant = quant_config.per_act_token_quant

     activation_method = (ActivationMethod.SILU
                          if activation == "silu" else ActivationMethod.GELU)
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index 4bbfea446e2..d674311330a 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
@@ -12,34 +13,39 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

-    def __init__(self,
-                 use_fp8_w8a8: bool = False,
-                 use_int8_w8a8: bool = False,
-                 use_int8_w8a16: bool = False,
-                 use_int4_w4a16: bool = False,
-                 per_channel_quant: bool = False,
-                 block_shape: Optional[list[int]] = None,
-                 block_m: Optional[int] = None,
-                 allow_deep_gemm: bool = False):
-        super().__init__()
-        self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
-                                           use_int8_w8a8=use_int8_w8a8,
-                                           use_int4_w4a16=use_int4_w4a16,
-                                           use_int8_w8a16=use_int8_w8a16,
-                                           per_channel_quant=per_channel_quant,
-                                           block_shape=block_shape,
-                                           block_m=block_m)
-        self.allow_deep_gemm = allow_deep_gemm
-        self.use_fp8_w8a8 = use_fp8_w8a8
+    def __init__(
+        self,
+        allow_deep_gemm: bool = False,
+        quant_config: Optional[FusedMoEQuantConfig] = None,
+    ):
+        super().__init__(quant_config)
+        self.triton_expert = TritonExperts(quant_config)
+        self.allow_deep_gemm = (allow_deep_gemm and not self.per_act_token_quant
+                                and self.quant_config.use_fp8_w8a8)
         self.deep_gemm_expert = DeepGemmExperts(
         ) if self.allow_deep_gemm else None

+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        assert (self.deep_gemm_expert is None
+                or self.triton_expert.activation_formats
+                == self.deep_gemm_expert.activation_formats)
+        return self.triton_expert.activation_formats
+
     def supports_chunking(self) -> bool:
         dge = self.deep_gemm_expert
         te = self.triton_expert
         return ((dge is None or dge.supports_chunking())
                 and (te is None or te.supports_chunking()))

+    def supports_expert_map(self) -> bool:
+        dge = self.deep_gemm_expert
+        te = self.triton_expert
+        return ((dge is None or dge.supports_expert_map())
+                and (te is None or te.supports_expert_map()))
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
@@ -83,9 +89,7 @@ def apply(
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
     ):
-        N = w1.size(1)
-
-        use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
+        use_deep_gemm = (self.allow_deep_gemm
                          and _valid_deep_gemm(hidden_states, w1, w2))

         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 692482c2ea6..921af0d1a1b 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -75,14 +75,14 @@ def _int8_quantize(

 def moe_kernel_quantize_input(
     A: torch.Tensor,
     A_scale: Optional[torch.Tensor],
-    qtype: Optional[torch.dtype],
-    per_channel_quant: bool,
+    quant_dtype: Optional[torch.dtype],
+    per_act_token_quant: bool,
     block_shape: Optional[list[int]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    if qtype == torch.float8_e4m3fn:
-        return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
-    elif qtype == torch.int8:
-        return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
+    if quant_dtype == torch.float8_e4m3fn:
+        return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
+    elif quant_dtype == torch.int8:
+        return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
     else:
         assert A_scale is None
         return A, A_scale
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index f14131c5f05..024a9ce4f46 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -14,8 +14,10 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe import (
+    CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
+    FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
+    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -32,13 +34,6 @@

 has_pplx = importlib.util.find_spec("pplx_kernels") is not None

-if current_platform.is_cuda_alike():
-    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-        BatchedPrepareAndFinalize)
-    if has_pplx:
-        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
-            PplxPrepareAndFinalize)
-
 logger = init_logger(__name__)

@@ -304,15 +299,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                                    requires_grad=False)

             self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
-        else:
-            from vllm.model_executor.layers.fused_moe import fused_experts
-            self.fused_experts_func = fused_experts
-
-        if self.use_marlin:
+        elif self.use_marlin:
             prepare_moe_fp8_layer_for_marlin(layer, False)
             # Activations not quantized for marlin.
             del layer.w13_input_scale
             del layer.w2_input_scale
+            self.fused_experts_func = None
+        else:
+            self.fused_experts_func = fused_experts

     def apply(
         self,
@@ -354,9 +348,11 @@ def apply(
                 topk_ids=topk_ids,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                per_channel_quant=self.weight_quant.strategy ==
-                QuantizationStrategy.CHANNEL,
+                quant_config=FusedMoEQuantConfig.make(
+                    use_fp8_w8a8=True,
+                    per_act_token_quant=self.weight_quant.strategy ==
+                    QuantizationStrategy.CHANNEL
+                ),
                 w1_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
                 a1_scale=layer.w13_input_scale,
@@ -379,6 +375,8 @@ def apply(
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)

+        assert self.fused_experts_func is not None
+
         return self.fused_experts_func(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -388,9 +386,11 @@ def apply(
             inplace=True,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            use_fp8_w8a8=True,
-            per_channel_quant=self.weight_quant.strategy ==
-            QuantizationStrategy.CHANNEL,
+            quant_config=FusedMoEQuantConfig.make(
+                use_fp8_w8a8=True,
+                per_act_token_quant=self.weight_quant.strategy ==
+                QuantizationStrategy.CHANNEL
+            ),
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,
@@ -552,28 +552,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)

-    def select_gemm_impl(self, prepare_finalize, moe):
-        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-            CutlassExpertsFp8)
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:

-        assert moe is not None
+        use_batched_format = (prepare_finalize.activation_format ==
+                              FusedMoEActivationFormat.BatchedExperts)
+
+        num_experts = (moe.num_local_experts
+                       if use_batched_format else moe.num_experts)

-        max_experts_per_worker = (
-            (moe.num_experts + prepare_finalize.world_size - 1) //
-            prepare_finalize.world_size)
         experts = CutlassExpertsFp8(
-            max_experts_per_worker,
+            num_experts,
             moe.in_dtype,
             self.input_quant.strategy == QuantizationStrategy.TOKEN,
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
-            use_batched_format=True,
+            use_batched_format=use_batched_format,
         )

-        if has_pplx and isinstance(
-                prepare_finalize,
-                (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
-            # no expert_map support in this case
-            self.disable_expert_map = True
+        self.disable_expert_map = not experts.supports_expert_map()
         return experts

     def apply(
@@ -606,7 +605,8 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=torch.uint32)
+            indices_type=self.topk_indices_dtype,
+        )

         return self.fused_experts(
             x,
@@ -746,8 +746,10 @@ def apply(
             inplace=True,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            use_int8_w8a8=True,
-            per_channel_quant=True,
+            quant_config=FusedMoEQuantConfig.make(
+                use_int8_w8a8=True,
+                per_act_token_quant=True,
+            ),
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,
@@ -1251,8 +1253,11 @@ def apply(
             topk_ids=topk_ids,
             inplace=True,
             activation=activation,
-            use_int4_w4a16=self.num_bits == 4,
-            use_int8_w8a16=self.num_bits == 8,
+            quant_config=FusedMoEQuantConfig.make(
+                use_int4_w4a16=self.num_bits == 4,
+                use_int8_w8a16=self.num_bits == 8,
+                block_shape=[0, self.group_size],
+            ),
             global_num_experts=global_num_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_map=expert_map,
@@ -1260,4 +1265,4 @@ def apply(
             w2_scale=layer.w2_weight_scale,
             w1_zp=None,
             w2_zp=None,
-            block_shape=[0, self.group_size])
+        )
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 01b0064f080..c8692e01a71 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -140,7 +140,7 @@ def apply(
             topk_ids=topk_ids,
             inplace=True,
             activation=activation,
-            use_int8_w8a16=True,
+            quant_config=FusedMoEQuantConfig.make(use_int8_w8a16=True),
             global_num_experts=global_num_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_map=expert_map,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b3042bfaed3..a3597198e87 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -3,7 +3,7 @@

 import functools
 import importlib.util
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional

 import torch
 import torch.nn.functional as F
@@ -14,8 +14,11 @@
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe import (
+    BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
+    FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
+    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
+    TritonOrDeepGemmExperts)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -467,8 +470,10 @@ def __init__(self, quant_config: Fp8Config):
             self.topk_indices_dtype = None
             self.fused_experts = functools.partial(  # type: ignore
                 fused_experts,
-                use_fp8_w8a8=True,
-                block_shape=self.quant_config.weight_block_size,
+                quant_config=FusedMoEQuantConfig.make(
+                    use_fp8_w8a8=True,
+                    block_shape=self.quant_config.weight_block_size
+                ),
                 allow_deep_gemm=self.allow_deep_gemm)

     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
@@ -770,44 +775,51 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale

-    def select_gemm_impl(self, prepare_finalize, moe):
-
-        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-            BatchedTritonOrDeepGemmExperts)
-        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
-            TritonOrDeepGemmExperts)
-
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:
         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")

-        experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
-                                TritonOrDeepGemmExperts]] = None
-
-        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
-        use_batched_experts = max_num_tokens_per_rank is not None
-
-        if use_batched_experts:
-            experts = BatchedTritonOrDeepGemmExperts(
+        if (prepare_finalize.activation_format ==
+                FusedMoEActivationFormat.BatchedExperts):
+            max_num_tokens_per_rank = (
+                prepare_finalize.max_num_tokens_per_rank())
+            assert max_num_tokens_per_rank is not None
+            logger.debug(
+                "BatchedTritonOrDeepGemmExperts(%s): "
+                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                self.__class__.__name__, max_num_tokens_per_rank,
+                self.quant_config.weight_block_size, False)
+            return BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
-                world_size=prepare_finalize.world_size,
-                dp_size=prepare_finalize.dp_size,
-                use_fp8_w8a8=True,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                per_channel_quant=False,
-                block_shape=self.quant_config.weight_block_size,
+                world_size=moe.world_size,
+                dp_size=moe.dp_size,
                 allow_deep_gemm=self.allow_deep_gemm,
+                quant_config=FusedMoEQuantConfig.make(
+                    moe.in_dtype,
+                    use_fp8_w8a8=True,
+                    block_shape=self.quant_config.weight_block_size,
+                    per_act_token_quant=False
+                ),
             )
         else:
-            experts = TritonOrDeepGemmExperts(
-                use_fp8_w8a8=True,
-                block_shape=self.quant_config.weight_block_size,
+            logger.debug(
+                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
+                self.__class__.__name__, self.quant_config.weight_block_size,
+                False)
+            return TritonOrDeepGemmExperts(
                 allow_deep_gemm=self.allow_deep_gemm,
+                quant_config=FusedMoEQuantConfig.make(
+                    moe.in_dtype,
+                    use_fp8_w8a8=True,
+                    block_shape=self.quant_config.weight_block_size,
+                    per_act_token_quant=False,
+                ),
             )

-        assert experts is not None
-        return experts
-
     def apply(
         self,
         layer: torch.nn.Module,
@@ -851,7 +863,10 @@ def apply(
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 activation=activation,
-                use_fp8_w8a8=True,
+                quant_config=FusedMoEQuantConfig.make(
+                    use_fp8_w8a8=True,
+                    block_shape=self.quant_config.weight_block_size,
+                ),
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 w1_scale=(layer.w13_weight_scale_inv
                           if self.block_quant else layer.w13_weight_scale),
@@ -859,7 +874,7 @@ def apply(
                           if self.block_quant else layer.w2_weight_scale),
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size)
+            )
         elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 3aa23f06825..daef5390e5e 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -322,8 +322,11 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
-            use_int4_w4a16=weight_bits == 4,
-            use_int8_w8a16=weight_bits == 8,
+            quant_config=FusedMoEQuantConfig.make(
+                use_int4_w4a16=weight_bits == 4,
+                use_int8_w8a16=weight_bits == 8,
+                block_shape=[0, layer.group_size],
+            ),
             global_num_experts=global_num_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_map=expert_map,
@@ -331,7 +334,7 @@ def apply(
             w2_scale=layer.w2_scales,
             w1_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
-            block_shape=[0, layer.group_size])
+        )

     @staticmethod
     def get_weight_loader(layer, weight_loader):
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 4c2da4c8b04..06ac384b7a7 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -227,7 +227,7 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
-            use_fp8_w8a8=True,
+            quant_config=FusedMoEQuantConfig.make(use_fp8_w8a8=True),
             global_num_experts=global_num_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_map=expert_map,