From db301e3ad506573c2b4d9996875d5c7f459eb762 Mon Sep 17 00:00:00 2001 From: gnovack Date: Fri, 14 Nov 2025 22:42:33 +0000 Subject: [PATCH 1/2] add support for --fully-sharded-loras in fused_moe Signed-off-by: gnovack --- tests/lora/test_olmoe_tp.py | 10 ++++-- vllm/lora/layers/fused_moe.py | 36 ++++++++++++++++--- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 21 +++++++++-- vllm/lora/punica_wrapper/punica_base.py | 2 ++ vllm/lora/punica_wrapper/punica_gpu.py | 4 +++ 5 files changed, 65 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py index e659c1e1a9a0..e3c9816625ba 100644 --- a/tests/lora/test_olmoe_tp.py +++ b/tests/lora/test_olmoe_tp.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -111,8 +113,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files): generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None]) +@pytest.mark.parametrize("fully_sharded_loras", [False, True]) @multi_gpu_test(num_gpus=2) -def test_olmoe_lora_tp2(olmoe_lora_files): +def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras): llm = vllm.LLM( MODEL_PATH, max_model_len=1024, @@ -122,14 +125,16 @@ def test_olmoe_lora_tp2(olmoe_lora_files): trust_remote_code=True, enable_chunked_prefill=True, tensor_parallel_size=2, + fully_sharded_loras=fully_sharded_loras, ) generate_and_test(llm, olmoe_lora_files, lora_id=1) generate_and_test(llm, olmoe_lora_files, lora_id=2) +@pytest.mark.parametrize("fully_sharded_loras", [False, True]) @multi_gpu_test(num_gpus=4) -def test_olmoe_lora_tp4(olmoe_lora_files): +def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras): llm = vllm.LLM( MODEL_PATH, max_model_len=1024, @@ -139,6 +144,7 @@ def test_olmoe_lora_tp4(olmoe_lora_files): trust_remote_code=True, enable_chunked_prefill=True, tensor_parallel_size=4, + fully_sharded_loras=fully_sharded_loras, ) generate_and_test(llm, olmoe_lora_files, lora_id=1) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 8fb3efa220f6..dc87950106d5 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -12,6 +12,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from vllm.distributed.utils import divide from vllm.lora.layers.base import BaseLayerWithLoRA from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.model_executor.layers.fused_moe import FusedMoE @@ -205,6 +206,7 @@ def wrapper(*args, **kwargs): shrink_config, ## pass the shrink config expand_config, ## pass the expand config self.adapter_enabled, + fully_sharded=self.fully_sharded, ) result = func(*args, **kwargs) @@ -250,7 +252,10 @@ def wrapper(*args, **kwargs): sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) intermediate_cache2 = moe_state_dict["intermediate_cache2"] intermediate_cache3 = args[0] - max_lora_rank = self.w1_lora_a_stacked.shape[-2] + max_lora_rank = self.w2_lora_b_stacked.shape[-1] + + shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size) + self.punica_wrapper.add_lora_fused_moe( intermediate_cache3, intermediate_cache2, @@ -266,6 +271,8 @@ def wrapper(*args, **kwargs): expand_config, ## pass the expand config self.adapter_enabled, True, + fully_sharded=self.fully_sharded, + offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0, ) result = func(*args, **kwargs) @@ -294,6 +301,7 @@ def create_lora_weights( model_config: PretrainedConfig | None = None, ) -> 
None: """Initializes lora matrices.""" + self.fully_sharded = lora_config.fully_sharded_loras self.adapter_enabled = torch.tensor( [0] * (max_loras + 1), dtype=torch.int, device=self.device @@ -303,7 +311,9 @@ def create_lora_weights( ( max_loras, self.base_layer.local_num_experts, - lora_config.max_lora_rank, + lora_config.max_lora_rank + if not self.fully_sharded + else divide(lora_config.max_lora_rank, self.tp_size), self.base_layer.hidden_size, ), dtype=lora_config.lora_dtype, @@ -334,7 +344,9 @@ def create_lora_weights( ( max_loras, self.base_layer.local_num_experts, - self.base_layer.hidden_size, + self.base_layer.hidden_size + if not self.fully_sharded + else divide(self.base_layer.hidden_size, self.tp_size), lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -345,7 +357,9 @@ def create_lora_weights( ( max_loras, self.base_layer.local_num_experts, - lora_config.max_lora_rank, + lora_config.max_lora_rank + if not self.fully_sharded + else divide(lora_config.max_lora_rank, self.tp_size), self.base_layer.hidden_size, ), dtype=lora_config.lora_dtype, @@ -419,6 +433,20 @@ def set_lora( w3_lora_b = w3_lora_b[start_idx:end_idx, :] w2_lora_a = w2_lora_a[:, start_idx:end_idx] + if self.fully_sharded: + # Based on S-LoRA, we slice W1 and W3 A along the rank dim, + # and W2 B along the hidden_size dim. + w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0] + w13_start_idx = self.tp_rank * w13_shard_size + w13_end_idx = (self.tp_rank + 1) * w13_shard_size + w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :] + w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :] + + w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0] + w2_start_idx = self.tp_rank * w2_shard_size + w2_end_idx = (self.tp_rank + 1) * w2_shard_size + w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :] + self.w1_lora_a_stacked[ index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1] ].copy_(w1_lora_a, non_blocking=True) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 893972144e99..0e320cf709b9 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -3,6 +3,10 @@ import torch +from vllm.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op @@ -311,6 +315,7 @@ def _fused_moe_lora_expand( num_stages: int, split_k: int, mul_routed_weight: bool = False, + offset: int = 0, ) -> None: b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank @@ -380,7 +385,7 @@ def _fused_moe_lora_expand( **expand_config, ) for i in range(num_slices): - output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i] + output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i] @torch.inference_mode() @@ -416,6 +421,8 @@ def _fused_moe_lora( expand_num_stages: int, expand_split_k: int, mul_routed_weight: bool = False, + fully_sharded: bool = False, + offset: int = 0, ) -> None: assert len(lora_a_stacked) == len(lora_b_stacked) > 0 assert ( @@ -430,7 +437,6 @@ def _fused_moe_lora( == expert_ids.shape[0] == num_tokens_post_padded.shape[0] ) - assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] assert output.shape[0] == topk_weights.shape[0] assert top_k_num == topk_weights.shape[1] device = qcurr_hidden_states.device @@ -480,6 +486,16 @@ def _fused_moe_lora( mul_routed_weight, ) + if fully_sharded: + if max_lora_rank == 
w1_lora_b_stacked.shape[-1]: + a_intermediate_cache1 = tensor_model_parallel_all_reduce( + a_intermediate_cache1 + ) + else: + a_intermediate_cache1 = tensor_model_parallel_all_gather( + a_intermediate_cache1 + ) + _fused_moe_lora_expand( output, a_intermediate_cache1, @@ -510,6 +526,7 @@ def _fused_moe_lora( expand_num_stages, expand_split_k, mul_routed_weight, + offset, ) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b6186e856152..a6ffbb7b71ce 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -483,6 +483,8 @@ def add_lora_fused_moe( expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, + fully_sharded: bool = False, + offset: int = 0, ): """ Performs a fused forward computation for LoRA of diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index ede50a48af98..d863a5884d3c 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -375,6 +375,8 @@ def add_lora_fused_moe( expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, + fully_sharded: bool = False, + offset: int = 0, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. @@ -408,4 +410,6 @@ def add_lora_fused_moe( expand_config.get("NUM_STAGES", 3), expand_config.get("SPLIT_K", 1), mul_routed_weight, + fully_sharded, + offset, ) From be495c2a46c538c88b3c130004792444e0247f7a Mon Sep 17 00:00:00 2001 From: gnovack Date: Tue, 18 Nov 2025 00:11:46 +0000 Subject: [PATCH 2/2] add test case for fully-sharded _fused_moe_lora Signed-off-by: gnovack --- tests/lora/test_fused_moe_lora_kernel.py | 208 +++++++++++++++++- vllm/lora/layers/fused_moe.py | 2 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 3 + 3 files changed, 210 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 91ab4a87c65f..91c8b861c3c5 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -1,13 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import random import pytest import torch +from tests.utils import multi_gpu_test from vllm import _custom_ops as ops +from vllm.distributed import ( + init_distributed_environment, + initialize_model_parallel, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size, +) from vllm.lora.ops.triton_ops import fused_moe_lora from vllm.platforms import current_platform +from vllm.utils.network_utils import get_open_port @pytest.fixture(autouse=True) @@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel( max_loras, num_experts, block_size, + fully_sharded=False, + offset=0, ): max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) @@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel( config["NUM_STAGES"], config["SPLIT_K"], mul_routed_weight, + fully_sharded=fully_sharded, + offset=offset, ) - return output - def use_torch( hidden_states, @@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel( ) torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("num_tokens", [100]) +@pytest.mark.parametrize("top_k_num", [6]) 
+@pytest.mark.parametrize("num_experts", [64]) +@pytest.mark.parametrize("max_loras", [4]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("column_parallel", [True, False]) +def test_fused_moe_lora_kernel_fully_sharded( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + dtype, + seed, + column_parallel, +): + current_platform.seed_everything(seed) + # the number of randomly generated sentences. + num_sequences = 10 + # generate data + topk_ids, topk_weights, token_lora_mapping = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + def run_torch_spawn(fn, nprocs): + torch.multiprocessing.spawn( + fn, + args=( + nprocs, + f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", + dtype, + seed, + N, + K, + num_tokens, + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + max_loras, + num_experts, + block_size, + column_parallel, + ), + nprocs=nprocs, + ) + + run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2) + + +def use_fused_moe_lora_kernel_tensor_parallel( + local_rank, + world_size, + init_method, + dtype, + seed, + N, + K, + num_tokens, + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + max_loras, + num_experts, + block_size, + column_parallel, +): + def _get_shard_slice(shard_size): + return slice(local_rank * shard_size, (local_rank + 1) * shard_size) + + current_platform.seed_everything(seed) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + init_distributed_environment( + world_size=world_size, + rank=local_rank, + local_rank=local_rank, + distributed_init_method=init_method, + ) + initialize_model_parallel(world_size, 1) + tp_size = get_tensor_model_parallel_world_size() + + input_dim = K if column_parallel else N + output_dim = N if column_parallel else K + + # init lora weights + lora_a = torch.rand( + ( + max_loras, + num_experts, + max_lora_rank, + input_dim, + ), + dtype=dtype, + ) + lora_b = torch.rand( + ( + max_loras, + num_experts, + output_dim, + max_lora_rank, + ), + dtype=dtype, + ) + + hidden_states = torch.rand( + ( + num_tokens, + input_dim, + ), + dtype=dtype, + ) + + output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype) + topk_ids = topk_ids.to(device) + topk_weights = topk_weights.to(device) + token_lora_mapping = token_lora_mapping.to(device) + + ref_output = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + [lora_a], + [lora_b], + top_k_num, + ) + + if column_parallel: + # Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim, + # and Lora B is sliced along the output dim + lora_a_shard_size = max_lora_rank // tp_size + lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :] + max_lora_rank = lora_a_shard_size + offset = 0 + + lora_b_shard_size = output_dim // tp_size + lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :] + output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous() + else: + # Row parallel (e.g. 
down proj): LoRA A is sliced along the input dim, + # and LoRA B is sliced along the output dim + lora_a_shard_size = input_dim // tp_size + lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)] + hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)] + + lora_b_shard_size = output_dim // tp_size + lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :] + offset = lora_b_shard_size * local_rank + + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + [lora_a], + [lora_b], + hidden_states, + output, + max_loras, + num_experts, + block_size, + fully_sharded=True, + offset=offset, + ) + + if column_parallel: + output = tensor_model_parallel_all_gather(output) + else: + output = tensor_model_parallel_all_reduce(output) + + torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dc87950106d5..3291c41fcda1 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -252,7 +252,7 @@ def wrapper(*args, **kwargs): sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) intermediate_cache2 = moe_state_dict["intermediate_cache2"] intermediate_cache3 = args[0] - max_lora_rank = self.w2_lora_b_stacked.shape[-1] + max_lora_rank = self.w2_lora_a_stacked.shape[-2] shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 0e320cf709b9..a986b8072382 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -496,6 +496,9 @@ def _fused_moe_lora( a_intermediate_cache1 ) + # reset max_lora_rank to the full rank after allgather + max_lora_rank = a_intermediate_cache1.shape[-1] + _fused_moe_lora_expand( output, a_intermediate_cache1,
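
A minimal, self-contained sketch (not part of the patch) of the column-parallel (gate_up_proj) case that the changes above implement: LoRA A is sliced along the rank dim, each rank runs the shrink GEMM on its slice, and an all-gather along the rank dim rebuilds the full shrink output before the expand GEMM. It runs on a single process with plain torch; all names and sizes (tp_size, w13_a, w13_b, ...) are illustrative, and LoRA B is kept unsharded here purely to check the math.

    import torch

    torch.manual_seed(0)

    tp_size, num_tokens, hidden, inter, rank = 2, 4, 8, 6, 4
    x = torch.randn(num_tokens, hidden)

    # One expert's gate/up LoRA: A is (rank, hidden), B is (inter, rank).
    w13_a = torch.randn(rank, hidden)
    w13_b = torch.randn(inter, rank)

    # Unsharded reference: shrink then expand.
    ref = x @ w13_a.T @ w13_b.T

    # Fully sharded: each rank holds rank // tp_size rows of A and computes a
    # slice of the shrink output; concatenating the slices along the last dim
    # plays the role of tensor_model_parallel_all_gather.
    r_shard = rank // tp_size
    shrink_slices = [
        x @ w13_a[r * r_shard : (r + 1) * r_shard, :].T for r in range(tp_size)
    ]
    gathered = torch.cat(shrink_slices, dim=-1)

    assert torch.allclose(gathered @ w13_b.T, ref, atol=1e-5)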
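
A companion sketch for the row-parallel (down_proj) case: LoRA A already sees a TP-sharded intermediate input, so the per-rank shrink outputs are partial sums that must be added (the all-reduce branch), while LoRA B is sliced along hidden_size and each rank writes its slice of the expand output at offset = shard_size * tp_rank; the base layer's existing down_proj all-reduce then combines the per-rank outputs. Again single-process and illustrative only.

    import torch

    torch.manual_seed(0)

    tp_size, num_tokens, hidden, inter, rank = 2, 4, 8, 6, 4
    y = torch.randn(num_tokens, inter)      # activations entering down_proj
    w2_a = torch.randn(rank, inter)         # LoRA A: (rank, intermediate)
    w2_b = torch.randn(hidden, rank)        # LoRA B: (hidden, rank)

    ref = y @ w2_a.T @ w2_b.T               # unsharded reference

    i_shard, h_shard = inter // tp_size, hidden // tp_size

    # Shrink on the sharded input yields partial sums; adding them up plays
    # the role of tensor_model_parallel_all_reduce.
    shrink_full = sum(
        y[:, r * i_shard : (r + 1) * i_shard]
        @ w2_a[:, r * i_shard : (r + 1) * i_shard].T
        for r in range(tp_size)
    )

    # Expand: each rank owns h_shard rows of B and writes its slice of the
    # output at offset = h_shard * tp_rank; accumulating into one tensor here
    # stands in for the down_proj all-reduce across ranks.
    out = torch.zeros(num_tokens, hidden)
    for tp_rank in range(tp_size):
        b_shard = w2_b[tp_rank * h_shard : (tp_rank + 1) * h_shard, :]
        offset = h_shard * tp_rank
        out[:, offset : offset + h_shard] += shrink_full @ b_shard.T

    assert torch.allclose(out, ref, atol=1e-5)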