
Conversation

@pavanimajety (Collaborator) commented Oct 23, 2025

Purpose

This PR switches the default FlashInfer MoE backend to the TRTLLM MoE kernels, which are optimized for latency scenarios.
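
For reference, the previous behavior remains available by pinning the backend explicitly. A minimal sketch, assuming the offline LLM API (the model name and parallel size are illustrative):

import os

# "throughput" selects the CUTLASS kernels; "latency" selects the TRTLLM kernels.
# This PR flips the default from "throughput" to "latency".
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"

from vllm import LLM  # set the variable before engine construction so it takes effect

llm = LLM(model="nvidia/Llama-3.3-70B-Instruct-FP4", tensor_parallel_size=2)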

Additionally, this PR addresses a few more issues:

  1. Move the zero initialization for FP4 quantization in padded scenarios into the kernel, avoiding an extra kernel call over the whole tensor (introduced in [Bugfix] Enable padded FP4 quantization #25947); see the sketch after this list.
  2. Fix the k_scale and v_scale loading again!
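
A minimal sketch of the idea behind item 1, at the PyTorch level (the buffer shapes and the quantization entry point are illustrative, not vLLM's actual API):

import torch

num_tokens, hidden, pad_to = 1000, 4096, 128
padded_tokens = (num_tokens + pad_to - 1) // pad_to * pad_to  # round up to the padding granularity

# Before: the padded FP4 output buffer was created with torch.zeros, which
# launches an extra fill over the whole tensor:
# out = torch.zeros(padded_tokens, hidden // 2, dtype=torch.uint8, device="cuda")

# After: allocate uninitialized memory and let the quantization kernel
# zero only the padded rows [num_tokens:padded_tokens) itself:
out = torch.empty(padded_tokens, hidden // 2, dtype=torch.uint8, device="cuda")
# fp4_quantize(x, out)  # hypothetical kernel call; the zero-fill now happens in-kernel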

Fixes: #26070

Test Plan

Test Nemotron (nvidia/NVIDIA-Nemotron-Nano-9B-v2) and Llama 3.3 70B FP4 for accuracy on GSM8K.

Test Result

nvidia/NVIDIA-Nemotron-Nano-9B-v2:
Original:

TP 2:
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [01:31<00:00, 14.38it/s]
Results: Original (zero initialized with torch.zeros)
Accuracy: 0.794
Invalid responses: 0.000
Total latency: 91.735 s
Questions per second: 14.378

Results: PR (zero initialized in the kernel)

root@gb-nvl-082-compute08:/workspace/scratch-pmaj-1/arm-gh-pm-vllm# python3 tests/evals/gsm8k/gsm8k_eval.py 
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl to /tmp/train.jsonl
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|█████████████████████████████████████████████████████████████████| 1319/1319 [01:31<00:00, 14.44it/s]

Results:
Accuracy: 0.792
Invalid responses: 0.000
Total latency: 91.362 s
Questions per second: 14.437

Llama 70B FP4 + TP2

original (torch.zeros)

root@gb-nvl-082-compute08:/workspace/scratch-pmaj-1/arm-gh-pm-vllm# python3 tests/evals/gsm8k/gsm8k_eval.py 
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████| 1319/1319 [00:25<00:00, 52.12it/s]

Results:
Accuracy: 0.923
Invalid responses: 0.001
Total latency: 25.318 s
Questions per second: 52.09

With PR:

root@gb-nvl-082-compute08:/workspace/scratch-pmaj-1/arm-gh-pm-vllm# python3 tests/evals/gsm8k/gsm8k_eval.py 
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████| 1319/1319 [00:25<00:00, 52.68it/s]

Results:
Accuracy: 0.923
Invalid responses: 0.001
Total latency: 25.052 s
Questions per second: 52.651

Performance

Server:

vllm serve --host 0.0.0.0 --port 8087 --model nvidia/Llama-3.3-70B-Instruct-FP4  --dtype auto --kv-cache-dtype fp8 --tensor-parallel-size 4 --pipeline-parallel-size 1 --data-parallel-size 1 --swap-space 16   --max-num-seqs 512 --trust-remote-code --max-model-len 2058 --gpu-memory-utilization 0.9 --max-num-batched-tokens 8192 --no-enable-prefix-caching --async-scheduling  --compilation_config.pass_config.enable_fi_allreduce_fusion true --compilation_config.pass_config.enable_attn_fusion true --compilation_config.pass_config.enable_noop true  --compilation_config.custom_ops+=+quant_fp8,+rms_norm --compilation_config.cudagraph_mode FULL_DECODE_ONLY --compilation_config.splitting_ops []

With torch.zeros (TOT main):

Maximum request concurrency: 8
100%|██████████████████████████████████████████████████████| 40/40 [00:28<00:00,  1.39it/s]
tip: install termplotlib and gnuplot to plot the metrics
============ Serving Benchmark Result ============
Successful requests:                     40        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  28.78     
Total input tokens:                      40920     
Total generated tokens:                  24008     
Request throughput (req/s):              1.39      
Output token throughput (tok/s):         834.33    
Peak output token throughput (tok/s):    1067.00   
Peak concurrent requests:                13.00     
Total Token throughput (tok/s):          2256.39   
---------------Time to First Token----------------
Mean TTFT (ms):                          129.44    
Median TTFT (ms):                        93.95     
P99 TTFT (ms):                           273.25    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.05      
Median TPOT (ms):                        8.03      
P99 TPOT (ms):                           8.86      
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.10      
Median ITL (ms):                         7.50      
P99 ITL (ms):                            8.60      
==================================================

With torch.empty (TOT main):

Maximum request concurrency: 8
100%|██████████████████████████████████████████████████████| 40/40 [00:27<00:00,  1.44it/s]
tip: install termplotlib and gnuplot to plot the metrics
============ Serving Benchmark Result ============
Successful requests:                     40        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  27.70     
Total input tokens:                      40920     
Total generated tokens:                  24636     
Request throughput (req/s):              1.44      
Output token throughput (tok/s):         889.52    
Peak output token throughput (tok/s):    1120.00   
Peak concurrent requests:                12.00     
Total Token throughput (tok/s):          2367.01   
---------------Time to First Token----------------
Mean TTFT (ms):                          126.30    
Median TTFT (ms):                        88.52     
P99 TTFT (ms):                           259.42    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.74      
Median TPOT (ms):                        7.76      
P99 TPOT (ms):                           9.49      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.69      
Median ITL (ms):                         7.14      
P99 ITL (ms):                            8.13      

PR with fix:

============ Serving Benchmark Result ============
Successful requests:                     40        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  26.08     
Total input tokens:                      40920     
Total generated tokens:                  22789     
Request throughput (req/s):              1.53      
Output token throughput (tok/s):         873.75    
Peak output token throughput (tok/s):    1128.00   
Peak concurrent requests:                13.00     
Total Token throughput (tok/s):          2442.67   
---------------Time to First Token----------------
Mean TTFT (ms):                          126.60    
Median TTFT (ms):                        94.03     
P99 TTFT (ms):                           283.60    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.67      
Median TPOT (ms):                        7.69      
P99 TPOT (ms):                           8.96      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.67      
Median ITL (ms):                         7.08      
P99 ITL (ms):                            8.09      


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@pavanimajety pavanimajety force-pushed the flashinfer-latency-moe-pick branch from b33adab to cd262df Compare October 27, 2025 17:05
@pavanimajety pavanimajety added the bug Something isn't working label Oct 27, 2025
@pavanimajety pavanimajety marked this pull request as ready for review October 27, 2025 21:08
@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

vllm/vllm/envs.py

Lines 1171 to 1180 in 34d8036

# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
# Both require compute capability 10.0 or above.
# Available options:
# - "throughput": [default]
#   Uses CUTLASS kernels optimized for high-throughput batch inference.
# - "latency":
#   Uses TensorRT-LLM kernels optimized for low-latency inference.
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
    "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
),

P1: FlashInfer MoE backend default not updated

The commit message claims to switch the default FlashInfer MoE backend to the latency‑optimised kernels, but the actual environment configuration still declares env_with_choices("VLLM_FLASHINFER_MOE_BACKEND", "throughput", ...). Only the type hint at the top of the file was changed, so the runtime default remains "throughput" and the code never adopts the intended latency backend unless the user sets the variable manually.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@pavanimajety (Collaborator, Author) commented:

@codex review

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@bnellnm (Collaborator) left a comment

LGTM!

@pavanimajety pavanimajety added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 31, 2025

namespace vllm {

#define round_up(x, y) ((x + y - 1) / y * y)
@nvjullin (Contributor) commented Nov 3, 2025

nit: the macro naming convention is ALL_CAPS_WITH_UNDERSCORE. But better yet, don't use a macro at all and use

#include <type_traits>  // for std::is_integral_v

template<typename Int>
__host__ __device__ static Int round_up(Int x, Int y)
{
    static_assert(std::is_integral_v<Int>, "round_up argument must be integral type");
    return (x + y - 1) / y * y;
}


@pavanimajety (Collaborator, Author) replied:

Thank you, fixed in 0c22d3c

@pavanimajety pavanimajety force-pushed the flashinfer-latency-moe-pick branch from 0ff1f63 to 0c22d3c Compare November 3, 2025 23:21
@pavanimajety (Collaborator, Author) commented:

The failed lm-eval-small-models test passes locally:

Log for test_gsm8k_correctness

# pytest tests/evals/gsm8k/test_gsm8k_correctness.py::test_gsm8k_correctness_param[Qwen1.5-MoE-W4A16-CT-tp1]
================================================================================================================================================= test session starts =================================================================================================================================================
platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/pm-vllm
configfile: pyproject.toml
plugins: asyncio-1.2.0, anyio-4.11.0
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 1 item                                                                                                                                                                                                                                                                                                      

tests/evals/gsm8k/test_gsm8k_correctness.py .                                                                                                                                                                                                                                          [100%]

====================================================================================================================================== warnings summary ======================================================================================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../../usr/local/lib/python3.12/dist-packages/_pytest/cacheprovider.py:475
  /usr/local/lib/python3.12/dist-packages/_pytest/cacheprovider.py:475: PytestCacheWarning: could not create cache path /workspace/pm-vllm/.pytest_cache/v/cache/nodeids: [Errno 13] Permission denied: '/workspace/pm-vllm/pytest-cache-files-txppmog7'
    config.cache.set("cache/nodeids", sorted(self.cached_nodeids))

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================================= 1 passed, 3 warnings in 224.90s (0:03:44) ==========================================================================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribut

`test_response_api_mcp_tools.py::test_mcp_tool_env_flag_enabled[openai/gpt-oss-20b]`

--
The other failures seem unrelated or are failing on main as well.

@pavanimajety (Collaborator, Author) commented:

Test failures are unrelated to this PR.

Comment on lines 1219 to 1221

  "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
-     "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
+     "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
  ),
A Member commented:

Is it the case that both FP8 and NVFP4 throughput won't be affected by this? I see you tested NVFP4, but this will affect several quantized MoE cases.

@pavanimajety (Collaborator, Author) replied:

That's right, we see good perf with the TRTLLM kernels across the board. We also have flashinfer-ai/flashinfer#2014 ([feat] Refactor trtllmgen MOE and add Bf16 trtllmgen moe) from @jiahanc that closes more gaps. Having this as the default also enables Expert Parallel in the default path for FP8, in addition to the performance gains.
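
A minimal sketch of that default path, assuming the offline LLM API (the checkpoint and parallel sizes are illustrative, not a tested configuration):

from vllm import LLM

# With the latency (TRTLLM) backend now the default, an FP8 MoE checkpoint can be
# served with expert parallelism without touching VLLM_FLASHINFER_MOE_BACKEND.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",  # illustrative FP8 MoE checkpoint
    tensor_parallel_size=8,
    enable_expert_parallel=True,
)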

@pavanimajety pavanimajety force-pushed the flashinfer-latency-moe-pick branch from 26f3005 to 903430c Compare November 6, 2025 18:54
@vllm-bot vllm-bot merged commit 72b1c2a into vllm-project:main Nov 7, 2025
87 of 91 checks passed
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
…misc fixes (vllm-project#27439)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
…misc fixes (vllm-project#27439)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…misc fixes (vllm-project#27439)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>

Labels

bug (Something isn't working), ready (ONLY add when PR is ready to merge/full CI is needed)


Development

Successfully merging this pull request may close these issues.

[Bug]: DSR1 FP4 + DEP8 on B200 fails with TensorRT-LLM throughput kernels

5 participants