From 8de2fd39fcb60f1a9cb84a34c5245b2b991561fe Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 18 Jun 2025 07:32:15 -0700 Subject: [PATCH 1/9] deep_ep + use_fp8_dispatch Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/layer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f217588..c6c908f73a2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -45,7 +45,8 @@ from .pplx_prepare_finalize import PplxPrepareAndFinalize if has_deepep: from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE, + DeepEPLLPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -377,6 +378,12 @@ def init_prepare_finalize(self, moe: MoEConfig, all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) + # Note : We may want to use FP8 dispatch even otherwise just to + # reduce datamovement + use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() + and act_quant_block_size + == DEEPEP_QUANT_BLOCK_SIZE) + # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. prepare_finalize = DeepEPLLPrepareAndFinalize( @@ -386,7 +393,7 @@ def init_prepare_finalize(self, moe: MoEConfig, max_tokens_per_rank=moe.max_num_tokens, quant_dtype=quant_dtype, block_shape=act_quant_block_size, - use_fp8_dispatch=False, + use_fp8_dispatch=use_fp8_dispatch, ) self.topk_indices_dtype = None From 299f8291803fc49b2c9ccefd070d374688bffcc8 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Jun 2025 20:09:51 +0000 Subject: [PATCH 2/9] DeepGEMM LL optimizations - Quantized dispatch - Fused act-and-mul-and-quant in the right layout for DeepGEMM Signed-off-by: Tyler Michael Smith --- .../layers/fused_moe/batched_deep_gemm_moe.py | 190 ++++++++++++++++-- 1 file changed, 174 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 5492399efdf..758cd7c56f7 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -6,14 +6,183 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.triton_utils import tl, triton logger = init_logger(__name__) has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # *FP32 activations (E, T, 2*H) + y_q_ptr, # *FP8 quantised activations (E, T, H) + y_s_ptr, # *FP32 scales (E, T, G) + counts_ptr, # *INT32 number of tokens per expert (E) + + # Sizes --------------------------------------------------------------- + E: tl.constexpr, # num_experts + T: tl.constexpr, # max_num_tokens + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + + # Strides for input 
(elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK) + cols = cols.to(tl.int64) + mask_h = cols < BLOCK + + t = tl.zeros([], tl.int64) + while t < n_tokens: + base_i_offset = (e * stride_i_e + t * stride_i_t + + g * GROUP_SIZE * stride_i_h) + base_yq_offset = (e * stride_yq_e + t * stride_yq_t + + g * GROUP_SIZE * stride_yq_h) + base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g + + mask = mask_h + x = tl.load(input_ptr + base_i_offset + cols * stride_i_h, + mask=mask, + other=0.0).to(tl.float32) + y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h + + cols * stride_i_h, + mask=mask, + other=0.0).to(tl.float32) + + x = x * (1.0 / (1.0 + tl.exp(-x))) + y = x * y2 + + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset, y_s) + + t += 1 + + +def silu_mul_fp8_quant_deep_gemm( + y: torch.Tensor, # (E, T, 2*H) float32 + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + eps: float = 1e-6, +): + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. + * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ + "tokens_per_expert must be shape (E,)" + tokens_per_expert = tokens_per_expert.to(device=y.device, + dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided((E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # static grid over experts and H-groups. 
+ # A loop inside the kernel handles the token dim + grid = (E * G, ) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = -f_info.max + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + E, + T, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + BLOCK=group_size, + num_warps=4, + ) + + return y_q, y_s + + class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 @@ -96,7 +265,6 @@ def apply( hidden_states, w1, w2, topk_ids) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) - workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) # (from deepgemm docs) : A value hint (which is a value on CPU) # for the M expectation of each batch, correctly setting this value @@ -109,19 +277,9 @@ def apply( masked_m=expert_num_tokens, expected_m=expected_m) - # TODO (varun) [Optimization]: Use a batched version of activation. - # Similarly for the quant below. - self.activation(activation, workspace2, workspace1.view(-1, N)) - - w2_hidden_size = workspace2.size(-1) - workspace2 = workspace2.view(-1, w2_hidden_size) - - a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(workspace2, - self.block_shape[1], - column_major_scales=False) - a2q = a2q.view(E, max_num_tokens, -1) - a2q_scale = a2q_scale.view(E, max_num_tokens, -1) + assert expert_num_tokens is not None + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, + expert_num_tokens) dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), From 2b5ad9f233627252860b1639711612a4db8c4554 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 18 Jun 2025 11:15:48 -0700 Subject: [PATCH 3/9] fixes - use-fp8-dispatch Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c6c908f73a2..98733f101ac 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -381,7 +381,7 @@ def init_prepare_finalize(self, moe: MoEConfig, # Note : We may want to use FP8 dispatch even otherwise just to # reduce datamovement use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() - and act_quant_block_size + and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE) # Note (varun): Whether to use FP8 dispatch or not needs some From d5f206767c04c82c30301897be703dbe6ee82939 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 14:39:58 +0000 Subject: [PATCH 4/9] Unit test Signed-off-by: Tyler Michael Smith --- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py new file mode 100644 index 00000000000..5cfb4266ff9 --- /dev/null +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + 
silu_mul_fp8_quant_deep_gemm) +from vllm.platforms import current_platform + +# (E, T, H, group_size, seed) +CASES = [ + (1, 1, 128, 64, 0), + (1, 4, 128, 128, 0), + (2, 4, 256, 128, 0), + (32, 64, 256, 128, 0), + (17, 31, 768, 128, 0), +] + + +@pytest.mark.parametrize("E,T,H,group_size,seed", CASES) +@torch.inference_mode() +def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): + current_platform.seed_everything(seed) + + # Input tensor of shape (E, T, 2*H) + y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda") + tokens_per_expert = torch.randint( + low=0, + high=T, + size=(E, ), + dtype=torch.int32, + device="cuda", + ) + + # Run the Triton kernel + y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, + tokens_per_expert, + group_size=group_size, + eps=1e-10) + + # Reference implementation + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max = fp8_info.max + fp8_min = fp8_info.min + eps = 1e-10 + + # Compute silu activation and elementwise multiplication + y1 = y[..., :H] + y2 = y[..., H:] + silu_x = y1 * torch.sigmoid(y1) + merged = silu_x * y2 + + # Compute reference scales and quantized output + ref_s = torch.empty((E, T, H // group_size), + dtype=torch.float32, + device="cuda") + ref_q = torch.empty((E, T, H), dtype=torch.float8_e4m3fn, device="cuda") + # Compute reference scales and quantized output, skipping padded tokens + for e in range(E): + nt = tokens_per_expert[e].item() + for t in range(nt): + data = merged[e, t] + data_grp = data.view(H // group_size, group_size) + amax = data_grp.abs().amax(dim=1).clamp(min=eps) + scale = amax / fp8_max + + scaled = data / scale.repeat_interleave(group_size) + clamped = scaled.clamp(fp8_min, fp8_max) + q = clamped.to(torch.float8_e4m3fn) + + ref_s[e, t] = scale + ref_q[e, t] = q + + # Compare scales and quantized outputs for valid tokens only + for e in range(E): + nt = tokens_per_expert[e].item() + torch.testing.assert_close(y_s[e, :nt], ref_s[e, :nt]) + torch.testing.assert_close( + y_q[e, :nt].to(torch.float32), + ref_q[e, :nt].to(torch.float32), + ) From 26fd8ca33c18f668c95e4328fba735c4881e59c0 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 14:40:21 +0000 Subject: [PATCH 5/9] fixes Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 758cd7c56f7..a92125a6fab 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -105,7 +105,7 @@ def silu_mul_fp8_quant_deep_gemm( y: torch.Tensor, # (E, T, 2*H) float32 tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert group_size: int = 128, - eps: float = 1e-6, + eps: float = 1e-10, ): """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales @@ -152,7 +152,7 @@ def silu_mul_fp8_quant_deep_gemm( f_info = torch.finfo(fp8_dtype) fp8_max = f_info.max - fp8_min = -f_info.max + fp8_min = f_info.min _silu_mul_fp8_quant_deep_gemm[grid]( y, From 7a821f0e7f6594496a6b2229e994a2868d0adc7e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 14:41:20 +0000 Subject: [PATCH 6/9] precommit Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py 
b/vllm/model_executor/layers/fused_moe/layer.py index 98733f101ac..4ed10e60b13 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -380,6 +380,7 @@ def init_prepare_finalize(self, moe: MoEConfig, # Note : We may want to use FP8 dispatch even otherwise just to # reduce datamovement + assert act_quant_block_size is not None use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE) From 39d5d33f8f15e823afa9abcb74265c1c474f4563 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 15:36:59 +0000 Subject: [PATCH 7/9] tweaks Signed-off-by: Tyler Michael Smith --- .../layers/fused_moe/batched_deep_gemm_moe.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a92125a6fab..fae8d3745fe 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -17,14 +17,12 @@ @triton.jit def _silu_mul_fp8_quant_deep_gemm( # Pointers ------------------------------------------------------------ - input_ptr, # *FP32 activations (E, T, 2*H) - y_q_ptr, # *FP8 quantised activations (E, T, H) - y_s_ptr, # *FP32 scales (E, T, G) - counts_ptr, # *INT32 number of tokens per expert (E) + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp88 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) # Sizes --------------------------------------------------------------- - E: tl.constexpr, # num_experts - T: tl.constexpr, # max_num_tokens H: tl.constexpr, # hidden dimension (per output) GROUP_SIZE: tl.constexpr, # elements per group (usually 128) @@ -159,8 +157,6 @@ def silu_mul_fp8_quant_deep_gemm( y_q, y_s, tokens_per_expert, - E, - T, H, group_size, stride_i_e, From 21ffc7353a258d3a72fa77a522222aa0af66ef11 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 15:56:05 +0000 Subject: [PATCH 8/9] fixup Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index fae8d3745fe..70836879d17 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -18,9 +18,9 @@ def _silu_mul_fp8_quant_deep_gemm( # Pointers ------------------------------------------------------------ input_ptr, # 16-bit activations (E, T, 2*H) - y_q_ptr, # fp88 quantized activations (E, T, H) + y_q_ptr, # fp8 quantized activations (E, T, H) y_s_ptr, # 16-bit scales (E, T, G) - counts_ptr, # int32 num tokens per expert (E) + counts_ptr, # int32 num tokens per expert (E) # Sizes --------------------------------------------------------------- H: tl.constexpr, # hidden dimension (per output) From b4f17e12a444a90d21528c412d19f6d7488494ba Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 19:47:25 +0000 Subject: [PATCH 9/9] tolerances Signed-off-by: Tyler Michael Smith --- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git 
a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5cfb4266ff9..673a0aa3679 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -51,14 +51,13 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): silu_x = y1 * torch.sigmoid(y1) merged = silu_x * y2 - # Compute reference scales and quantized output - ref_s = torch.empty((E, T, H // group_size), - dtype=torch.float32, - device="cuda") - ref_q = torch.empty((E, T, H), dtype=torch.float8_e4m3fn, device="cuda") # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() + ref_s = torch.empty((T, H // group_size), + dtype=torch.float32, + device="cuda") + ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") for t in range(nt): data = merged[e, t] data_grp = data.view(H // group_size, group_size) @@ -69,14 +68,16 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): clamped = scaled.clamp(fp8_min, fp8_max) q = clamped.to(torch.float8_e4m3fn) - ref_s[e, t] = scale - ref_q[e, t] = q + ref_s[t] = scale + ref_q[t] = q - # Compare scales and quantized outputs for valid tokens only - for e in range(E): - nt = tokens_per_expert[e].item() - torch.testing.assert_close(y_s[e, :nt], ref_s[e, :nt]) + y_se = y_s[e] + y_qe = y_q[e] + + torch.testing.assert_close(y_se[:nt], ref_s[:nt]) torch.testing.assert_close( - y_q[e, :nt].to(torch.float32), - ref_q[e, :nt].to(torch.float32), + y_qe[:nt].to(torch.float32), + ref_q[:nt].to(torch.float32), + atol=2, + rtol=2e-1, )
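
For readers following the series, below is a minimal usage sketch of the silu_mul_fp8_quant_deep_gemm entry point added in PATCH 2/9 and exercised by the unit test above. It is a sketch, not part of the patches: it assumes a CUDA device with float8_e4m3fn support and a vLLM build containing this series, and the sizes E, T, H and group_size are illustrative placeholders rather than values taken from any model configuration. The shape, dtype, and stride expectations restate what the function's docstring and allocation code in PATCH 2/9 describe.

    import torch

    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        silu_mul_fp8_quant_deep_gemm)

    # Illustrative sizes: E experts, up to T tokens per expert, hidden size H.
    E, T, H = 8, 16, 256
    group_size = 128
    G = H // group_size

    # Batched activations laid out as (E, T, 2*H): the first H columns are
    # SiLU-activated and multiplied elementwise by the second H columns.
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")

    # Per-expert valid-token counts; rows with t >= tokens_per_expert[e] are
    # padding and are skipped by the kernel (and by the unit test above).
    tokens_per_expert = torch.randint(0, T, (E, ),
                                      dtype=torch.int32,
                                      device="cuda")

    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert,
                                            group_size=group_size)

    # y_q: (E, T, H) quantized activations in float8_e4m3fn.
    # y_s: (E, T, H // group_size) float32 per-group scales with strides
    #      (T * G, 1, T), per the docstring in PATCH 2/9.
    assert y_q.shape == (E, T, H) and y_q.dtype == torch.float8_e4m3fn
    assert y_s.shape == (E, T, G) and y_s.stride() == (T * G, 1, T)

The (T * G, 1, T) scale layout is the one BatchedDeepGemmExperts.apply() relies on in PATCH 2/9, where (a2q, a2q_scale) are passed straight to dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked; whether other consumers need a different layout is not covered by this series.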