Skip to content

[EP+DP] Optimize the little operations in the DeepGEMM + DeepEP low latency case #19885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

tlrmchlsmth
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth commented Jun 20, 2025

Summary

Optimizations when using DeepGEMM + DeepEP (for the decode worker in a P/D setup). On current main, the majority of time is spent in quantize ops, silu-and-mul, and copy operations to put the scales in the layout needed for DeepGEMM.

This PR has two pieces:

  1. Set use_fp8_dispatch to True in DeepEP's dispatch operation. This lets DeepEP take care of the quantization, and puts the scales in the right layout for DeepGEMM.
  2. A triton fused silu-mul-quant kernel that produces tensors with the right shape/strides needed by DeepGEMM. This kernel also avoids computing with or loading padded tokens. The latter is very important since we need a lot of padding for CUDA graph support in this case.

Performance

Running the following pure decode benchmark:

python vllm/benchmarks/benchmark_serving.py \
    --base-url http://vllm-leader:8080 \
    --model deepseek-ai/DeepSeek-R1-0528 \
    --dataset-name random \
    --random-input-len 1 \
    --random-output-len 1000  \
    --request-rate 512 \
    --seed $(date +%M%H%M%S) \
    --num-prompts 512 \
    --ignore-eos

Main:

============ Serving Benchmark Result ============
Successful requests:                     512
Benchmark duration (s):                  333.30
Total input tokens:                      0
Total generated tokens:                  512000
Request throughput (req/s):              1.54
Output token throughput (tok/s):         1536.17
Total Token throughput (tok/s):          1536.17
---------------Time to First Token----------------
Mean TTFT (ms):                          475.77
Median TTFT (ms):                        468.43
P99 TTFT (ms):                           665.10
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          331.71
Median TPOT (ms):                        331.72
P99 TPOT (ms):                           331.75
---------------Inter-token Latency----------------
Mean ITL (ms):                           331.71
Median ITL (ms):                         331.76
P99 ITL (ms):                            345.27
==================================================

This PR:

============ Serving Benchmark Result ============
Successful requests:                     512
Benchmark duration (s):                  138.74
Total input tokens:                      0
Total generated tokens:                  512000
Request throughput (req/s):              3.69
Output token throughput (tok/s):         3690.26
Total Token throughput (tok/s):          3690.26
---------------Time to First Token----------------
Mean TTFT (ms):                          151.18
Median TTFT (ms):                        147.35
P99 TTFT (ms):                           283.58
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          137.54
Median TPOT (ms):                        137.58
P99 TPOT (ms):                           137.66
---------------Inter-token Latency----------------
Mean ITL (ms):                           137.54
Median ITL (ms):                         138.05
P99 ITL (ms):                            152.67
==================================================

Testing

This is covered by tests/kernels/moe/test_deepep_deepgemm_moe.py which remains green. I ran lm_eval gsm8k on Qwen/Qwen3-235B-A22B-FP8 and the score was around 0.9 (I stupidly threw away the result).

TODO: Unit test for the fused act-and-mul-and-quant kernel.

Varun Sundar Rabindranath and others added 4 commits June 18, 2025 07:32
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
- Quantized dispatch
- Fused act-and-mul-and-quant in the right layout for DeepGEMM

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @tlrmchlsmth, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on critical performance optimizations for the DeepGEMM + DeepEP low-latency path, particularly for decode workers. My changes streamline the data flow and computation by integrating quantization and activation operations directly into the DeepEP dispatch and introducing a highly optimized fused kernel. This significantly reduces overhead from previously identified bottlenecks like quantization, SiLU-and-mul, and data copying, leading to substantial throughput and latency improvements.

Highlights

  • DeepEP Dispatch Optimization: I've enabled use_fp8_dispatch in DeepEP's dispatch operation. This allows DeepEP to handle the quantization process directly, ensuring that the scales are in the correct layout for DeepGEMM and reducing the need for separate copy operations.
  • Fused SiLU-Mul-Quant Kernel: I've introduced a new Triton-based fused silu-mul-quant kernel. This kernel efficiently performs the SiLU activation, multiplication, and FP8 quantization in a single operation. It's designed to produce tensors with the precise shape and strides required by DeepGEMM and is optimized to avoid processing or loading padded tokens, which is critical for CUDA graph support in low-latency scenarios.
  • Performance Improvement: These optimizations have led to significant performance gains in the pure decode benchmark. Request throughput and output token throughput have both improved by approximately 2.4x, while Time to First Token (TTFT) is about 3.1x faster and Time per Output Token (TPOT) is about 2.4x faster.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the qwen Related to Qwen models label Jun 20, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR introduces optimizations for the DeepGEMM + DeepEP low latency case, including setting use_fp8_dispatch to True in DeepEP's dispatch operation and a triton fused silu-mul-quant kernel. The changes significantly improve performance, as demonstrated by the benchmark results. The code is well-structured and includes necessary assertions. However, some comments could be added to improve readability and explain certain design choices.

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clean implementation, thanks for the summary! I think we can land with fixed pre-commit. It would be nice to have a kernel test comparing this fused version versus the previous version, but that may be overkill.


f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = -f_info.max
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: might as well use f_info.min

@mgoin mgoin added deepseek Related to DeepSeek models performance Performance-related issues labels Jun 20, 2025
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
@tlrmchlsmth tlrmchlsmth requested a review from WoosukKwon as a code owner June 20, 2025 14:41
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) June 20, 2025 15:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 20, 2025
Signed-off-by: Tyler Michael Smith <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
deepseek Related to DeepSeek models performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants