
Conversation

@gnovack (Contributor) commented Nov 14, 2025

Purpose

This PR extends --fully-sharded-loras support to MoE LoRA adapters, allowing S-LoRA style sharding of the adapter weights across tensor-parallel ranks. This reduces the amount of GPU memory required per rank when using LoRA with tensor parallelism.
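
As a rough illustration of why this helps (a back-of-envelope sketch of my own, not vLLM code; the shapes below are invented, not Qwen3-30B's): with replicated adapters, every tensor-parallel rank holds the full stacked LoRA A/B buffers for all --max-loras slots, while with --fully-sharded-loras each rank only holds its 1/tp slice of them.

    # Hypothetical back-of-envelope: per-rank LoRA buffer size for one module,
    # replicated vs. fully sharded across tensor-parallel ranks.
    def lora_buffer_bytes(num_loras: int, rank: int, in_dim: int,
                          out_dim: int, bytes_per_elem: int = 2) -> int:
        """Stacked LoRA A (in_dim x rank) and B (rank x out_dim) buffers."""
        return num_loras * rank * (in_dim + out_dim) * bytes_per_elem

    tp = 4
    full = lora_buffer_bytes(num_loras=16, rank=32, in_dim=4096, out_dim=14336)
    replicated_per_rank = full      # every rank keeps the whole buffer
    sharded_per_rank = full // tp   # S-LoRA style: each rank keeps a 1/tp slice
    # Real models multiply this across all layers, experts, and projections.
    print(f"replicated: {replicated_per_rank / 2**20:.1f} MiB per rank")
    print(f"fully sharded: {sharded_per_rank / 2**20:.1f} MiB per rank")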

Example using Qwen3-30B

Before PR:

vllm serve Qwen/Qwen3-Coder-30B-A3B-Instruct -tp 4 --enable-lora --max-loras 16 --max-lora-rank 32 --fully-sharded-loras
...
[gpu_model_runner.py:3126] Model loading took 53.9761 GiB memory and 11.972274 seconds
...

After PR:

vllm serve Qwen/Qwen3-Coder-30B-A3B-Instruct -tp 4 --enable-lora --max-loras 16 --max-lora-rank 32 --fully-sharded-loras
...
[gpu_model_runner.py:3126] Model loading took 26.9761 GiB memory and 12.145540 seconds
...
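
A quick sanity check on those numbers (my own arithmetic, assuming essentially all of the savings comes from no longer replicating the adapter buffers): with tp=4, sharding removes (1 - 1/4) of the previously replicated buffer per rank, so a ~27 GiB drop implies roughly 36 GiB of per-rank adapter buffers before the change and ~9 GiB after.

    # Rough sanity check on the reported numbers (assumption: the whole
    # saving comes from sharding previously replicated LoRA buffers).
    tp = 4
    saving_gib = 53.9761 - 26.9761             # ~27 GiB saved per rank
    buffer_before = saving_gib / (1 - 1 / tp)  # ~36 GiB replicated per rank
    buffer_after = buffer_before / tp          # ~9 GiB per rank once sharded
    print(buffer_before, buffer_after)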

Test Plan

  • Added a fully_sharded_loras case to the existing MoE LoRA TP test cases to validate that the model output remains the same when this flag is enabled.
  • Added new kernel-level test cases for test_fused_moe_lora_kernel to validate that kernel outputs match before and after this change.

Signed-off-by: gnovack <gnovack@amazon.com>
@gnovack requested a review from jeejeelee as a code owner, November 14, 2025 22:57
@gemini-code-assist bot left a comment


Code Review

This pull request adds support for fully sharded LoRA adapters in FusedMoE, which is a great feature for reducing memory usage with tensor parallelism. The changes in the test files and Punica wrappers look good. However, I've found a critical issue in the implementation within vllm/lora/layers/fused_moe.py that will lead to incorrect results when using fully sharded LoRAs with tensor parallelism. The logic for aggregating the partial LoRA results for the w2 layer is missing an all-reduce operation, and it also attempts to add to an uninitialized tensor. I've provided a detailed comment and a suggested fix for this issue.
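
For context, the pattern this refers to looks roughly like the following (a minimal sketch of my own, not the actual code in vllm/lora/layers/fused_moe.py; the tensor name is hypothetical): the per-rank partial w2 LoRA result has to be accumulated into a zero-initialized buffer and then combined across tensor-parallel ranks.

    import torch
    from vllm.distributed import tensor_model_parallel_all_reduce

    def combine_w2_lora_partials(partial_w2_out: torch.Tensor) -> torch.Tensor:
        # Accumulate into an explicitly zeroed buffer; adding into an
        # uninitialized tensor (e.g. torch.empty) yields garbage results.
        combined = torch.zeros_like(partial_w2_out)
        combined += partial_w2_out
        # Each rank only holds a partial sum of the row-parallel w2 output,
        # so the partials must be all-reduced to get the full contribution.
        return tensor_model_parallel_all_reduce(combined)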

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


On the diff hunk around the all-gather of a_intermediate_cache1:

    a_intermediate_cache1
    )
    else:
    a_intermediate_cache1 = tensor_model_parallel_all_gather(

jeejeelee (Collaborator) commented:
QQ: Is PDL compatible with fully_sharded?

gnovack (Contributor, author) replied:

I'll test this out today and let you know. My assumption is that PDL probably won't work here due to the collective comms between shrink and expand, but I will confirm this.

gnovack (Contributor, author) replied:

From what I can tell based on profiling results, it seems that PDL does not provide a benefit when fully_sharded is enabled, as the communication op between shrink and expand is not overlapped with either one (see below):
[profiling trace screenshot: fully_sharded enabled, 11-17-25]

However, even when fully_sharded is disabled, it does not look like the MoE shrink/expand calls are getting overlapped (this seems to be caused by the torch.zeros call before the expand kernel):
[profiling trace screenshot: fully_sharded disabled, 11-17-25]

Let me know if there is some other/better test I can run to check whether PDL is working as expected before/after my changes.

jeejeelee (Collaborator) replied:

I've observed similar behavior as well. At least PDL didn't cause any undefined behavior. I'll take some more time to look into it later.

gnovack (Contributor) replied:

Thanks @jeejeelee!

Signed-off-by: gnovack <gnovack@amazon.com>
On the diff hunk where the expand results are added into the output:

    )
    for i in range(num_slices):
    -   output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
    +   output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
jeejeelee (Collaborator) commented:

BTW, we can fuse this add into the expand kernel, so we don't need to create b_intermediate_cache1 explicitly.
This should save the overhead of 2 kernels (empty and add). We can complete this in a follow-up PR.
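
As a rough sketch of that follow-up idea (my own illustration, not the planned kernel change; shapes and names are invented, and plain torch can only mimic the data flow, not the actual kernel fusion): the expand step would accumulate directly into the offset slice of output instead of materializing b_intermediate_cache1 and adding it in a separate kernel.

    import torch

    # Illustrative shapes: slices, tokens, top-k, lora rank, slice width, offset.
    num_slices, T, K, R, N, offset = 2, 8, 4, 32, 16, 0
    a = torch.randn(num_slices, T, K, R)   # stand-in for the shrink output
    b = torch.randn(num_slices, R, N)      # stand-in for the sharded lora_b
    output = torch.zeros(T, K, num_slices * N + offset)

    # Current flow: materialize the per-slice expand result, then add it in.
    b_intermediate_cache1 = torch.einsum("stkr,srn->stkn", a, b)
    for i in range(num_slices):
        output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]

    # Fused idea, mimicked in torch: accumulate straight into a view of the
    # output slice, so no intermediate buffer (and no separate add kernel) is
    # needed. In the real kernel this would be the expand kernel writing to
    # `output` at the slice offset.
    output2 = torch.zeros_like(output)
    for i in range(num_slices):
        output2[:, :, i * N + offset : (i + 1) * N + offset] += torch.einsum(
            "tkr,rn->tkn", a[i], b[i]
        )

    assert torch.allclose(output, output2, atol=1e-4)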

@jeejeelee added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 19, 2025
@jeejeelee merged commit d69062c into vllm-project:main Nov 19, 2025
47 of 48 checks passed
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
bhagyashrigai pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Nov 20, 2025
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Bhagyashri <Bhagyashri.Gaikwad2@ibm.com>
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: LuminolT <lumischen01@gmail.com>
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>