208 changes: 206 additions & 2 deletions tests/lora/test_fused_moe_lora_kernel.py
@@ -1,13 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random

import pytest
import torch

from tests.utils import multi_gpu_test
from vllm import _custom_ops as ops
from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port


@pytest.fixture(autouse=True)
@@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
max_loras,
num_experts,
block_size,
fully_sharded=False,
offset=0,
):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
@@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
config["NUM_STAGES"],
config["SPLIT_K"],
mul_routed_weight,
fully_sharded=fully_sharded,
offset=offset,
)

return output


def use_torch(
hidden_states,
@@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
)

torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("max_loras", [4])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("column_parallel", [True, False])
def test_fused_moe_lora_kernel_fully_sharded(
num_tokens,
top_k_num,
num_experts,
max_loras,
N,
K,
max_lora_rank,
block_size,
dtype,
seed,
column_parallel,
):
current_platform.seed_everything(seed)
# the number of randomly generated sequences.
num_sequences = 10
# generate data
topk_ids, topk_weights, token_lora_mapping = sample_data(
num_tokens, num_sequences, max_loras, num_experts, top_k_num
)

def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(
fn,
args=(
nprocs,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
dtype,
seed,
N,
K,
num_tokens,
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
max_loras,
num_experts,
block_size,
column_parallel,
),
nprocs=nprocs,
)

run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)


def use_fused_moe_lora_kernel_tensor_parallel(
local_rank,
world_size,
init_method,
dtype,
seed,
N,
K,
num_tokens,
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
max_loras,
num_experts,
block_size,
column_parallel,
):
def _get_shard_slice(shard_size):
return slice(local_rank * shard_size, (local_rank + 1) * shard_size)

current_platform.seed_everything(seed)

device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)

init_distributed_environment(
world_size=world_size,
rank=local_rank,
local_rank=local_rank,
distributed_init_method=init_method,
)
initialize_model_parallel(world_size, 1)
tp_size = get_tensor_model_parallel_world_size()

input_dim = K if column_parallel else N
output_dim = N if column_parallel else K

# init lora weights
lora_a = torch.rand(
(
max_loras,
num_experts,
max_lora_rank,
input_dim,
),
dtype=dtype,
)
lora_b = torch.rand(
(
max_loras,
num_experts,
output_dim,
max_lora_rank,
),
dtype=dtype,
)

hidden_states = torch.rand(
(
num_tokens,
input_dim,
),
dtype=dtype,
)

output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
topk_ids = topk_ids.to(device)
topk_weights = topk_weights.to(device)
token_lora_mapping = token_lora_mapping.to(device)

ref_output = use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
[lora_a],
[lora_b],
top_k_num,
)

if column_parallel:
# Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
# and LoRA B is sliced along the output dim
lora_a_shard_size = max_lora_rank // tp_size
lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
max_lora_rank = lora_a_shard_size
offset = 0

lora_b_shard_size = output_dim // tp_size
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
else:
# Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
# and LoRA B is sliced along the output dim
lora_a_shard_size = input_dim // tp_size
lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]

lora_b_shard_size = output_dim // tp_size
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
offset = lora_b_shard_size * local_rank

use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
[lora_a],
[lora_b],
hidden_states,
output,
max_loras,
num_experts,
block_size,
fully_sharded=True,
offset=offset,
)

if column_parallel:
output = tensor_model_parallel_all_gather(output)
else:
output = tensor_model_parallel_all_reduce(output)

torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)
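
The column-parallel branch in the test above slices LoRA A along the rank dim and LoRA B along the output dim before the final all-gather. Below is a minimal single-process sketch (not part of the PR) of why that S-LoRA-style slicing is lossless; the shrink-result gather and the output gather are simulated with `torch.cat`, and all names and sizes are illustrative only — the fused kernel may organize the intermediate differently.

```python
import torch

torch.manual_seed(0)
tp_size, rank, in_dim, out_dim, tokens = 2, 16, 32, 24, 5

x = torch.rand(tokens, in_dim)
lora_a = torch.rand(rank, in_dim)      # (rank, input_dim)
lora_b = torch.rand(out_dim, rank)     # (output_dim, rank)

# Unsharded reference: full LoRA product.
ref = x @ lora_a.t() @ lora_b.t()

# Shrink: each TP rank holds a slice of A along the rank dim.
r = rank // tp_size
shrink_shards = [
    x @ lora_a[tp_rank * r:(tp_rank + 1) * r, :].t()   # (tokens, rank/tp)
    for tp_rank in range(tp_size)
]

# Simulated all-gather of the shrink results along the rank dim.
intermediate = torch.cat(shrink_shards, dim=-1)        # (tokens, rank)

# Expand: each TP rank holds a slice of B along the output dim and produces
# its output shard; gathering the shards recovers the full output.
n = out_dim // tp_size
out_shards = [
    intermediate @ lora_b[tp_rank * n:(tp_rank + 1) * n, :].t()
    for tp_rank in range(tp_size)
]

out = torch.cat(out_shards, dim=-1)
torch.testing.assert_close(out, ref)
```
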
10 changes: 8 additions & 2 deletions tests/lora/test_olmoe_tp.py
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -111,8 +113,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])


@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files):
def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
@@ -122,14 +125,16 @@ def test_olmoe_lora_tp2(olmoe_lora_files):
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
fully_sharded_loras=fully_sharded_loras,
)

generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)


@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=4)
def test_olmoe_lora_tp4(olmoe_lora_files):
def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
@@ -139,6 +144,7 @@ def test_olmoe_lora_tp4(olmoe_lora_files):
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
fully_sharded_loras=fully_sharded_loras,
)

generate_and_test(llm, olmoe_lora_files, lora_id=1)
36 changes: 32 additions & 4 deletions vllm/lora/layers/fused_moe.py
@@ -12,6 +12,7 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -205,6 +206,7 @@ def wrapper(*args, **kwargs):
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
fully_sharded=self.fully_sharded,
)

result = func(*args, **kwargs)
@@ -250,7 +252,10 @@ def wrapper(*args, **kwargs):
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
max_lora_rank = self.w2_lora_a_stacked.shape[-2]

shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)

self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
Expand All @@ -266,6 +271,8 @@ def wrapper(*args, **kwargs):
expand_config, ## pass the expand config
self.adapter_enabled,
True,
fully_sharded=self.fully_sharded,
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
)

result = func(*args, **kwargs)
@@ -294,6 +301,7 @@ def create_lora_weights(
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.fully_sharded = lora_config.fully_sharded_loras

self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
Expand All @@ -303,7 +311,9 @@ def create_lora_weights(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
lora_config.max_lora_rank
if not self.fully_sharded
else divide(lora_config.max_lora_rank, self.tp_size),
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
@@ -334,7 +344,9 @@ def create_lora_weights(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
@@ -345,7 +357,9 @@ def create_lora_weights(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
lora_config.max_lora_rank
if not self.fully_sharded
else divide(lora_config.max_lora_rank, self.tp_size),
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
@@ -419,6 +433,20 @@ def set_lora(
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]

if self.fully_sharded:
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
# and W2 B along the hidden_size dim.
w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
w13_start_idx = self.tp_rank * w13_shard_size
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]

w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
w2_start_idx = self.tp_rank * w2_shard_size
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]

self.w1_lora_a_stacked[
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
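
The `fully_sharded` branch above slices w1/w3 LoRA A along the rank dim and w2 LoRA B along the hidden_size dim, and the w2 expand writes at `offset = shard_size_w2 * tp_rank`. The sketch below is a single-process illustration of the down_proj (w2) arithmetic, assuming the per-rank shrink partials are summed before the expand (the role the all-reduce plays in vLLM's fully sharded linear LoRA layers); shapes and names are illustrative only, not the layer's actual buffers.

```python
import torch

torch.manual_seed(0)
tp_size, rank, hidden, inter, tokens = 2, 16, 24, 32, 5

x = torch.rand(tokens, inter)          # down_proj input (intermediate acts)
lora_a = torch.rand(rank, inter)       # w2 LoRA A: (rank, intermediate_size)
lora_b = torch.rand(hidden, rank)      # w2 LoRA B: (hidden_size, rank)

# Unsharded reference.
ref = x @ lora_a.t() @ lora_b.t()

# Shrink: the input dim is already sharded for down_proj, so each rank's
# shrink result is a partial sum; the plain sum stands in for the all-reduce.
k = inter // tp_size
intermediate = sum(
    x[:, r * k:(r + 1) * k] @ lora_a[:, r * k:(r + 1) * k].t()
    for r in range(tp_size)
)

# Expand: B is sliced along hidden_size; each rank writes its output slice
# at `offset`, so summing the per-rank buffers (the later MoE all-reduce)
# restores the full LoRA contribution.
n = hidden // tp_size
out = torch.zeros(tokens, hidden)
for r in range(tp_size):
    offset = r * n
    out[:, offset:offset + n] += intermediate @ lora_b[offset:offset + n, :].t()

torch.testing.assert_close(out, ref)
```
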