[EP+DP] Optimize the little operations in the DeepGEMM + DeepEP low latency case #19885

Status: Open. Wants to merge 10 commits into base: main. Changes shown from all commits.
83 changes: 83 additions & 0 deletions tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)
from vllm.platforms import current_platform

# (E, T, H, group_size, seed)
CASES = [
    (1, 1, 128, 64, 0),
    (1, 4, 128, 128, 0),
    (2, 4, 256, 128, 0),
    (32, 64, 256, 128, 0),
    (17, 31, 768, 128, 0),
]


@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
    current_platform.seed_everything(seed)

    # Input tensor of shape (E, T, 2*H)
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E, ),
        dtype=torch.int32,
        device="cuda",
    )

    # Run the Triton kernel
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
                                            tokens_per_expert,
                                            group_size=group_size,
                                            eps=1e-10)

    # Reference implementation
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max
    fp8_min = fp8_info.min
    eps = 1e-10

    # Compute silu activation and elementwise multiplication
    y1 = y[..., :H]
    y2 = y[..., H:]
    silu_x = y1 * torch.sigmoid(y1)
    merged = silu_x * y2

    # Compute reference scales and quantized output, skipping padded tokens
    for e in range(E):
        nt = tokens_per_expert[e].item()
        ref_s = torch.empty((T, H // group_size),
                            dtype=torch.float32,
                            device="cuda")
        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
        for t in range(nt):
            data = merged[e, t]
            data_grp = data.view(H // group_size, group_size)
            amax = data_grp.abs().amax(dim=1).clamp(min=eps)
            scale = amax / fp8_max

            scaled = data / scale.repeat_interleave(group_size)
            clamped = scaled.clamp(fp8_min, fp8_max)
            q = clamped.to(torch.float8_e4m3fn)

            ref_s[t] = scale
            ref_q[t] = q

        y_se = y_s[e]
        y_qe = y_q[e]

        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
        torch.testing.assert_close(
            y_qe[:nt].to(torch.float32),
            ref_q[:nt].to(torch.float32),
            atol=2,
            rtol=2e-1,
        )
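For reference, the group-wise quantization that the test's reference loop performs reduces to a few lines. The sketch below is a hypothetical standalone helper (not part of this PR); it assumes a contiguous 1-D row whose length is divisible by group_size and mirrors the same amax, scale, and clamp steps:

import torch


def ref_group_quant_fp8(row: torch.Tensor,
                        group_size: int = 128,
                        eps: float = 1e-10):
    """Hypothetical reference helper: group-wise FP8 (e4m3) quantization of a
    1-D activation row, mirroring the per-token loop in the test above."""
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    groups = row.view(-1, group_size)               # (H // group_size, group_size)
    amax = groups.abs().amax(dim=1).clamp(min=eps)  # per-group absolute maximum
    scale = amax / fp8_info.max                     # one float32 scale per group
    q = (groups / scale[:, None]).clamp(fp8_info.min, fp8_info.max)
    return q.to(torch.float8_e4m3fn).view_as(row), scale

The Triton kernel added in the next file computes the same quantities per (expert, token, group), fused with the SiLU-and-multiply that produces the row.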
186 changes: 170 additions & 16 deletions vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -6,14 +6,179 @@

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
        # Pointers ------------------------------------------------------------
        input_ptr,  # 16-bit activations (E, T, 2*H)
        y_q_ptr,  # fp8 quantized activations (E, T, H)
        y_s_ptr,  # float32 scales (E, T, G)
        counts_ptr,  # int32 num tokens per expert (E)

        # Sizes ---------------------------------------------------------------
        H: tl.constexpr,  # hidden dimension (per output)
        GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)

        # Strides for input (elements) ---------------------------------------
        stride_i_e,
        stride_i_t,
        stride_i_h,

        # Strides for y_q (elements) -----------------------------------------
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,

        # Strides for y_s (elements) -----------------------------------------
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,

        # Stride for counts (elements)
        stride_counts_e,

        # Numeric params ------------------------------------------------------
        eps: tl.constexpr,
        fp8_min: tl.constexpr,
        fp8_max: tl.constexpr,

        # Meta ---------------------------------------------------------------
        BLOCK: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK)
    cols = cols.to(tl.int64)
    mask_h = cols < BLOCK

    t = tl.zeros([], tl.int64)
    while t < n_tokens:
        base_i_offset = (e * stride_i_e + t * stride_i_t +
                         g * GROUP_SIZE * stride_i_h)
        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
                          g * GROUP_SIZE * stride_yq_h)
        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g

        mask = mask_h
        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
                    mask=mask,
                    other=0.0).to(tl.float32)
        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
                     cols * stride_i_h,
                     mask=mask,
                     other=0.0).to(tl.float32)

        x = x * (1.0 / (1.0 + tl.exp(-x)))
        y = x * y2

        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset, y_s)

        t += 1


def silu_mul_fp8_quant_deep_gemm(
        y: torch.Tensor,  # (E, T, 2*H) float32
        tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
        group_size: int = 128,
        eps: float = 1e-10,
):
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.

    Returns `(y_q, y_s)` where
    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = H // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
        "tokens_per_expert must be shape (E,)"
    tokens_per_expert = tokens_per_expert.to(device=y.device,
                                             dtype=torch.int32)

    # allocate outputs
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # strides (elements)
    stride_i_e, stride_i_t, stride_i_h = y.stride()
    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

    # desired scale strides (elements): (T*G, 1, T)
    stride_ys_e = T * G
    stride_ys_t = 1
    stride_ys_g = T
    y_s = torch.empty_strided((E, T, G),
                              (stride_ys_e, stride_ys_t, stride_ys_g),
                              dtype=torch.float32,
                              device=y.device)

    stride_cnt_e = tokens_per_expert.stride()[0]

    # static grid over experts and H-groups.
    # A loop inside the kernel handles the token dim
    grid = (E * G, )

    f_info = torch.finfo(fp8_dtype)
    fp8_max = f_info.max
    fp8_min = f_info.min

    _silu_mul_fp8_quant_deep_gemm[grid](
        y,
        y_q,
        y_s,
        tokens_per_expert,
        H,
        group_size,
        stride_i_e,
        stride_i_t,
        stride_i_h,
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,
        stride_cnt_e,
        eps,
        fp8_min,
        fp8_max,
        BLOCK=group_size,
        num_warps=4,
    )

    return y_q, y_s


class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

# The Deep Gemm kernels only support block size of 128
Expand Down Expand Up @@ -96,7 +261,6 @@ def apply(
hidden_states, w1, w2, topk_ids)

workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
@@ -109,19 +273,9 @@
masked_m=expert_num_tokens,
expected_m=expected_m)

# TODO (varun) [Optimization]: Use a batched version of activation.
# Similarly for the quant below.
self.activation(activation, workspace2, workspace1.view(-1, N))

w2_hidden_size = workspace2.size(-1)
workspace2 = workspace2.view(-1, w2_hidden_size)

a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
self.block_shape[1],
column_major_scales=False)
a2q = a2q.view(E, max_num_tokens, -1)
a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)

dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
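For reviewers who want to exercise the new entry point in isolation, here is a minimal usage sketch (assumptions: a CUDA device with FP8 support, the default group_size=128, and illustrative shapes and per-expert token counts):

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)

E, T, H = 4, 64, 1024  # experts, max tokens per expert, hidden size per output
y = torch.randn(E, T, 2 * H, dtype=torch.float32, device="cuda")
tokens_per_expert = torch.tensor([64, 3, 0, 17],
                                 dtype=torch.int32,
                                 device="cuda")

y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert)

G = H // 128
assert y_q.shape == (E, T, H) and y_q.dtype == torch.float8_e4m3fn
assert y_s.shape == (E, T, G)
# Scales are written with strides (T * G, 1, T): contiguous along the token
# dimension within each expert, exactly as the kernel stores them.
assert y_s.stride() == (T * G, 1, T)

Only the first tokens_per_expert[e] rows of each expert are written; the remaining rows of y_q and y_s are left uninitialized, which is why the test above compares just the valid prefix.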
12 changes: 10 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -45,7 +45,8 @@
from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
DeepEPLLPrepareAndFinalize)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
@@ -377,6 +378,13 @@ def init_prepare_finalize(self, moe: MoEConfig,
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)

# Note: we may want to use FP8 dispatch even when this condition
# does not hold, just to reduce data movement.
assert act_quant_block_size is not None
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
and act_quant_block_size[1]
== DEEPEP_QUANT_BLOCK_SIZE)

# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
@@ -386,7 +394,7 @@
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=False,
use_fp8_dispatch=use_fp8_dispatch,
)

self.topk_indices_dtype = None
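Finally, the layer.py change can be read as a small predicate. The sketch below is a hypothetical helper (not part of this PR) that mirrors the condition wired into init_prepare_finalize above; it assumes DeepEP is installed so the DEEPEP_QUANT_BLOCK_SIZE import resolves, and that act_quant_block_size is a [block_n, block_k]-style list:

from typing import Optional

import torch

from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
    DEEPEP_QUANT_BLOCK_SIZE)
from vllm.platforms import current_platform


def should_use_fp8_dispatch(quant_dtype: Optional[torch.dtype],
                            act_quant_block_size: Optional[list[int]]) -> bool:
    """Hypothetical helper mirroring the check above: only ship FP8 data
    through DeepEP's low-latency dispatch when activations are quantized to
    the platform's FP8 dtype with the block size DeepEP expects."""
    return (quant_dtype == current_platform.fp8_dtype()
            and act_quant_block_size is not None
            and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE)

Whether FP8 dispatch is always a net win may still warrant profiling; the condition above only enables it for configurations whose activation quant block matches DeepEP's.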