
[Feature] Integrate new deepgemm #19820

Open · wants to merge 8 commits into base: main
Conversation

@yewentao256 (Contributor) commented Jun 18, 2025

Purpose

DeepGEMM is updating to v2.0 (deepseek-ai/DeepGEMM#112).

All of its API interfaces change, so this PR migrates vLLM to the new version.
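For context, a minimal sketch (not this PR's actual code) of how a caller could resolve the renamed entry points, e.g. the old gemm_fp8_fp8_bf16_nt versus the new fp8_gemm_nt, assuming both layouts expose the kernels at the top level of the deep_gemm module:

```python
# Sketch only: prefer the v2.0 kernel name and fall back to the pre-2.0 name.
# The name pair comes from the review summary below; whether vLLM keeps such a
# fallback at all is an implementation choice, not something this PR promises.
import deep_gemm


def resolve_kernel(new_name: str, old_name: str):
    fn = getattr(deep_gemm, new_name, None) or getattr(deep_gemm, old_name, None)
    if fn is None:
        raise ImportError(f"deep_gemm exposes neither {new_name!r} nor {old_name!r}")
    return fn


fp8_gemm_nt = resolve_kernel("fp8_gemm_nt", "gemm_fp8_fp8_bf16_nt")
```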

Test

On B200:

deep-gemm:

python benchmark_moe.py --use-deep-gemm --dtype fp8_w8a8
Batch size: 1, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 113.39 us
Batch size: 2, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 142.13 us
Batch size: 4, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 181.18 us
Batch size: 8, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 223.60 us
Batch size: 16, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 702.07 us
Batch size: 24, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 705.23 us
Batch size: 32, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 708.10 us
Batch size: 48, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 711.43 us
Batch size: 64, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 713.10 us
Batch size: 96, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 725.92 us
Batch size: 128, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 464.54 us
Batch size: 256, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 494.35 us
Batch size: 512, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 579.40 us
Batch size: 1024, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 764.41 us
Batch size: 1536, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 947.47 us
Batch size: 2048, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 1193.57 us
Batch size: 3072, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 1567.98 us
Batch size: 4096, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 1949.95 us

Normal:

python benchmark_moe.py --dtype fp8_w8a8
INFO 06-18 16:54:45 [__init__.py:244] Automatically detected platform cuda.
Namespace(model='mistralai/Mixtral-8x7B-Instruct-v0.1', tp_size=2, dtype='fp8_w8a8', use_deep_gemm=False, seed=0, batch_size=None, tune=False, trust_remote_code=False, model_prefix=None)
2025-06-18 16:54:48,239 INFO worker.py:1888 -- Started a local Ray instance.
(pid=2214862) INFO 06-18 16:54:52 [__init__.py:244] Automatically detected platform cuda.
(BenchmarkWorker pid=2214862) WARNING 06-18 16:54:54 [fused_moe.py:683] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
Batch size: 1, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 106.07 us
Batch size: 2, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 132.09 us
Batch size: 4, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 163.15 us
Batch size: 8, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 187.90 us
Batch size: 16, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 233.51 us
Batch size: 24, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 236.29 us
Batch size: 32, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 237.42 us
Batch size: 48, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 240.64 us
Batch size: 64, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 243.90 us
Batch size: 96, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 258.02 us
Batch size: 128, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 265.80 us
Batch size: 256, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 393.16 us
Batch size: 512, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 578.31 us
Batch size: 1024, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 964.15 us
Batch size: 1536, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 1352.90 us
Batch size: 2048, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 1744.83 us
Batch size: 3072, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 2534.58 us
Batch size: 4096, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 3310.33 us

[image: plot of the kernel-time results above]


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small but essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @yewentao256, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on integrating the new v2.0 API of the DeepGEMM library into the codebase. This migration was necessary due to significant interface changes in DeepGEMM. The changes span across core quantization layers, MoE implementations, and associated benchmark and test suites, ensuring that vLLM remains compatible and functional with the updated DeepGEMM version while maintaining performance and correctness.

Highlights

  • DeepGEMM API Migration: The core DeepGEMM integration has been updated to align with its v2.0 API. This includes renaming of several key GEMM functions (e.g., gemm_fp8_fp8_bf16_nt to fp8_gemm_nt, and various m_grouped_gemm functions to their fp8_m_grouped counterparts) and adjusting import paths for DeepGEMM's internal utility functions.
  • Utility Function Integration: Several utility functions previously used or implicitly available through DeepGEMM (like per_token_cast_to_fp8, per_block_cast_to_fp8, ceil_div, get_col_major_tma_aligned_tensor) are now explicitly imported from DeepGEMM's utils.math module or copied directly into vLLM's fp8_utils.py to ensure continued functionality and compatibility; a sketch of one such helper follows this list.
  • Benchmark and Test Suite Updates: The benchmark_fp8_block_dense_gemm.py script and various kernel test files (test_deepep_deepgemm_moe.py, test_block_fp8.py) have been updated to reflect the new DeepGEMM API calls and tensor handling, ensuring correctness and providing performance comparisons with the updated library.
  • MoE Layer Adjustments: The DeepGEMM-specific Mixture-of-Experts (MoE) layers (batched_deep_gemm_moe.py, deep_gemm_moe.py) have been modified to use the new DeepGEMM API for grouped and masked GEMM operations, maintaining the functionality of FP8 MoE computations.
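To illustrate what these helpers do, here is a minimal sketch of a 128×128 per-block FP8 cast in the spirit of DeepGEMM's reference per_block_cast_to_fp8 (the copies in fp8_utils.py and the upstream implementation may differ in detail):

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def per_block_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sketch: quantize a 2-D tensor to FP8 with one scale per 128x128 block."""
    assert x.dim() == 2
    m, n = x.shape
    # Pad both dimensions up to a multiple of 128 so every block is full.
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
        dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    # View as (m_blocks, 128, n_blocks, 128) and take the absmax of each block.
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    # 448 is the largest finite value representable in float8_e4m3fn.
    x_fp8 = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_fp8.view_as(x_padded)[:m, :n].contiguous(), scales
```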

@gemini-code-assist (bot) left a comment

Code Review

This pull request integrates DeepGEMM v2.0, which involves updating API interfaces across several files. The changes primarily consist of replacing old DeepGEMM function calls with their new counterparts and adjusting related logic, such as tensor initializations and the handling of helper functions for quantization and tensor alignment.

Key changes include:

  • Updated import statements and function calls to reflect DeepGEMM v2.0 API.
  • Removal of some locally copied DeepGEMM helper functions, preferring the versions from the deep_gemm library itself.
  • Adjustments in benchmark and test files to align with the new API, including changes to tensor initialization and the removal of explicit tensor alignment steps in some cases, presumably because the new API handles alignment itself or expects inputs in a different format.
  • Introduction of get_tma_aligned_size and get_col_major_tma_aligned_tensor functions into vllm/model_executor/layers/quantization/utils/fp8_utils.py as temporary measures, with a TODO to remove them once exposed by DeepGEMM.

The changes appear to be a careful migration. The main point for future improvement is addressing the TODO for the duplicated functions.

@houseroad (Collaborator) left a comment

Looks good to me.



# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
def get_tma_aligned_size(x: int, element_size: int) -> int:
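    # Sketch of the helper body for context only; the copied implementation in
    # fp8_utils.py may differ. Assumes the usual 16-byte TMA alignment.
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return (x + alignment - 1) // alignment * alignment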
Collaborator:

We can add some unit tests for these utility functions.

Contributor Author (@yewentao256):

It is a temporary function; we will delete it once DeepGEMM exposes the function, and I believe it is well tested on their side. So perhaps we don't need a unit test for this?

Member:

I agree a unit test is likely not needed since this is copied from DeepGEMM.

@houseroad added the ready (ONLY add when PR is ready to merge/full CI is needed) and deepseek (Related to DeepSeek models) labels on Jun 19, 2025
@mgoin (Member) left a comment

Looks good to me, great work @yewentao256! To confirm, will the interface change be the same for Hopper and Blackwell? If so, I think we should wait to land this until the DG PR lands. Also, can you show an e2e perf improvement for DeepSeek?




@yewentao256 (Contributor, Author) commented Jun 20, 2025

> Looks good to me, great work @yewentao256! To confirm, will the interface change be the same for Hopper and Blackwell? If so, I think we should wait to land this until the DG PR lands. Also, can you show an e2e perf improvement for DeepSeek?

Yes, it is the same for the two.

I did more e2e benchmarks and found something strange:

VLLM_USE_DEEP_GEMM=1 vllm bench throughput  --model Qwen/Qwen3-30B-A3B-FP8 --load-format dummy --input-len 1000 --output-len 100 --trust_remote_code --enforce-eager --enable-expert-parallel --quantization fp8
Throughput: 26.34 requests/s, 28916.28 total tokens/s, 2634.07 output tokens/s
vllm bench throughput  --model Qwen/Qwen3-30B-A3B-FP8 --load-format dummy --input-len 1000 --output-len 100 --trust_remote_code --enforce-eager --enable-expert-parallel --quantization fp8
Throughput: 36.65 requests/s, 40270.79 total tokens/s, 3665.06 output tokens/s
VLLM_USE_DEEP_GEMM=1 vllm bench throughput  --model deepseek-ai/DeepSeek-R1 --load-format dummy --input-len 32 --output-len 128 --trust_remote_code --enforce-eager -tp 8 --enable-expert-parallel  --no-enable-prefix-caching
Throughput: 23.89 requests/s, 3821.89 total tokens/s, 3058.29 output tokens/s
# NO deepgemm
Throughput: 42.59 requests/s, 6811.08 total tokens/s, 5451.01 output tokens/s

I suspect the throughput run suffers from DeepGEMM's JIT compilation overhead.

So, benchmarking with latency instead (which includes warm-up iterations):

VLLM_USE_DEEP_GEMM=1 vllm bench latency  --model deepseek-ai/DeepSeek-R1  --load-format dummy  --input-len 32 --output-len 128 --batch-size 8   --num-iters-warmup 2  --num-iters 5 --trust_remote_code --enforce-eager -tp 8 --enable-expert-parallel  --no-enable-prefix-caching
Avg latency: 14.85632159441011 seconds
10% percentile latency: 14.821913873415905 seconds
25% percentile latency: 14.843483201984782 seconds
50% percentile latency: 14.846536317025311 seconds
75% percentile latency: 14.886424494034145 seconds
90% percentile latency: 14.893147580395453 seconds
99% percentile latency: 14.897181432212237 seconds
# No deepgemm
Avg latency: 14.617215584986843 seconds
10% percentile latency: 14.35443497018423 seconds
25% percentile latency: 14.620368588948622 seconds
50% percentile latency: 14.657322109036613 seconds
75% percentile latency: 14.802983859961387 seconds
90% percentile latency: 14.818148029572331 seconds
99% percentile latency: 14.827246531338897 seconds

Seems the same now.

@yewentao256 (Contributor, Author) commented:

Adding lm_eval test results:

lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-30B-A3B-FP8,max_model_len=32768 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8317|±  |0.0103|
|     |       |strict-match    |     5|exact_match||0.8886|±  |0.0087|

VLLM_USE_DEEP_GEMM=1 lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-30B-A3B-FP8,max_model_len=32768 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto --force_eager
vllm (pretrained=Qwen/Qwen3-30B-A3B-FP8,enforce_eager=True,max_model_len=32768,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8309|±  |0.0103|
|     |       |strict-match    |     5|exact_match||0.8961|±  |0.0084|

Labels
deepseek (Related to DeepSeek models), performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)