
Commit 8899397

gnovack and jeejeelee authored and committed

add support for --fully-sharded-loras in fused_moe (vllm-project#28761)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

1 parent c4dffd8 commit 8899397

File tree: 6 files changed, +274 -10 lines changed

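Before the per-file diffs, here is a minimal single-process sketch of the sharding scheme this commit enables: with fully sharded LoRA, each tensor-parallel rank keeps only a slice of every expert's LoRA weights instead of a full replica. Following the S-LoRA-style comment added in vllm/lora/layers/fused_moe.py below, W1/W3 LoRA A is sliced along the rank dimension and W2 LoRA B along the hidden_size dimension. The shapes and variable names here are illustrative, not taken from the vLLM code.

```python
import torch

tp_size = 2
max_lora_rank, hidden_size = 16, 64

# Full per-expert LoRA factors as loaded from an adapter (illustrative shapes).
w1_lora_a = torch.randn(max_lora_rank, hidden_size)  # shrink weight of w1/w3 (gate_up proj)
w2_lora_b = torch.randn(hidden_size, max_lora_rank)  # expand weight of w2 (down proj)

for tp_rank in range(tp_size):
    # W1/W3 LoRA A: slice along the rank dim.
    r_shard = max_lora_rank // tp_size
    a_slice = w1_lora_a[tp_rank * r_shard:(tp_rank + 1) * r_shard, :]

    # W2 LoRA B: slice along the hidden_size (output) dim.
    h_shard = hidden_size // tp_size
    b_slice = w2_lora_b[tp_rank * h_shard:(tp_rank + 1) * h_shard, :]

    print(f"rank {tp_rank}: w1_lora_a slice {tuple(a_slice.shape)}, "
          f"w2_lora_b slice {tuple(b_slice.shape)}")
```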

tests/lora/test_fused_moe_lora_kernel.py

Lines changed: 206 additions & 2 deletions
@@ -1,13 +1,25 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 import random
 
 import pytest
 import torch
 
+from tests.utils import multi_gpu_test
 from vllm import _custom_ops as ops
+from vllm.distributed import (
+    init_distributed_environment,
+    initialize_model_parallel,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_world_size,
+)
 from vllm.lora.ops.triton_ops import fused_moe_lora
 from vllm.platforms import current_platform
+from vllm.utils.network_utils import get_open_port
 
 
 @pytest.fixture(autouse=True)
@@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
     max_loras,
     num_experts,
     block_size,
+    fully_sharded=False,
+    offset=0,
 ):
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
@@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
         config["NUM_STAGES"],
         config["SPLIT_K"],
         mul_routed_weight,
+        fully_sharded=fully_sharded,
+        offset=offset,
     )
 
-    return output
-
 
 def use_torch(
     hidden_states,
@@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
     )
 
     torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("num_tokens", [100])
+@pytest.mark.parametrize("top_k_num", [6])
+@pytest.mark.parametrize("num_experts", [64])
+@pytest.mark.parametrize("max_loras", [4])
+@pytest.mark.parametrize("N", [1408])
+@pytest.mark.parametrize("K", [2048])
+@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("column_parallel", [True, False])
+def test_fused_moe_lora_kernel_fully_sharded(
+    num_tokens,
+    top_k_num,
+    num_experts,
+    max_loras,
+    N,
+    K,
+    max_lora_rank,
+    block_size,
+    dtype,
+    seed,
+    column_parallel,
+):
+    current_platform.seed_everything(seed)
+    # the number of randomly generated sentences.
+    num_sequences = 10
+    # generate data
+    topk_ids, topk_weights, token_lora_mapping = sample_data(
+        num_tokens, num_sequences, max_loras, num_experts, top_k_num
+    )
+
+    def run_torch_spawn(fn, nprocs):
+        torch.multiprocessing.spawn(
+            fn,
+            args=(
+                nprocs,
+                f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
+                dtype,
+                seed,
+                N,
+                K,
+                num_tokens,
+                topk_ids,
+                topk_weights,
+                token_lora_mapping,
+                max_lora_rank,
+                top_k_num,
+                max_loras,
+                num_experts,
+                block_size,
+                column_parallel,
+            ),
+            nprocs=nprocs,
+        )
+
+    run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)
+
+
+def use_fused_moe_lora_kernel_tensor_parallel(
+    local_rank,
+    world_size,
+    init_method,
+    dtype,
+    seed,
+    N,
+    K,
+    num_tokens,
+    topk_ids,
+    topk_weights,
+    token_lora_mapping,
+    max_lora_rank,
+    top_k_num,
+    max_loras,
+    num_experts,
+    block_size,
+    column_parallel,
+):
+    def _get_shard_slice(shard_size):
+        return slice(local_rank * shard_size, (local_rank + 1) * shard_size)
+
+    current_platform.seed_everything(seed)
+
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    init_distributed_environment(
+        world_size=world_size,
+        rank=local_rank,
+        local_rank=local_rank,
+        distributed_init_method=init_method,
+    )
+    initialize_model_parallel(world_size, 1)
+    tp_size = get_tensor_model_parallel_world_size()
+
+    input_dim = K if column_parallel else N
+    output_dim = N if column_parallel else K
+
+    # init lora weights
+    lora_a = torch.rand(
+        (
+            max_loras,
+            num_experts,
+            max_lora_rank,
+            input_dim,
+        ),
+        dtype=dtype,
+    )
+    lora_b = torch.rand(
+        (
+            max_loras,
+            num_experts,
+            output_dim,
+            max_lora_rank,
+        ),
+        dtype=dtype,
+    )
+
+    hidden_states = torch.rand(
+        (
+            num_tokens,
+            input_dim,
+        ),
+        dtype=dtype,
+    )
+
+    output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
+    topk_ids = topk_ids.to(device)
+    topk_weights = topk_weights.to(device)
+    token_lora_mapping = token_lora_mapping.to(device)
+
+    ref_output = use_torch(
+        hidden_states,
+        token_lora_mapping,
+        topk_ids,
+        [lora_a],
+        [lora_b],
+        top_k_num,
+    )
+
+    if column_parallel:
+        # Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
+        # and Lora B is sliced along the output dim
+        lora_a_shard_size = max_lora_rank // tp_size
+        lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
+        max_lora_rank = lora_a_shard_size
+        offset = 0
+
+        lora_b_shard_size = output_dim // tp_size
+        lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
+        output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
+    else:
+        # Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
+        # and LoRA B is sliced along the output dim
+        lora_a_shard_size = input_dim // tp_size
+        lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
+        hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]
+
+        lora_b_shard_size = output_dim // tp_size
+        lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
+        offset = lora_b_shard_size * local_rank
+
+    use_fused_moe_lora_kernel(
+        topk_ids,
+        topk_weights,
+        token_lora_mapping,
+        max_lora_rank,
+        top_k_num,
+        [lora_a],
+        [lora_b],
+        hidden_states,
+        output,
+        max_loras,
+        num_experts,
+        block_size,
+        fully_sharded=True,
+        offset=offset,
+    )
+
+    if column_parallel:
+        output = tensor_model_parallel_all_gather(output)
+    else:
+        output = tensor_model_parallel_all_reduce(output)
+
+    torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)
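The new multi-GPU test reassembles the per-rank results with tensor_model_parallel_all_gather (column parallel) or tensor_model_parallel_all_reduce (row parallel, where each rank writes its shard at `offset` into a zero-filled full-size buffer). The sketch below is a single-process stand-in for those collectives; it only illustrates why the two recombination strategies recover the unsharded output and does not invoke the fused kernel.

```python
import torch

tp_size, num_tokens, top_k, output_dim = 2, 4, 2, 8
full = torch.randn(num_tokens, top_k, output_dim)  # stand-in for the unsharded result
shard = output_dim // tp_size

# Row-parallel style: each rank holds only its output slice, written at `offset`
# into a zero-initialized full-size buffer; summing (all-reduce) restores the tensor.
partials = []
for rank in range(tp_size):
    buf = torch.zeros_like(full)
    offset = shard * rank
    buf[..., offset:offset + shard] = full[..., offset:offset + shard]
    partials.append(buf)
assert torch.equal(sum(partials), full)

# Column-parallel style: each rank produces a contiguous output shard;
# concatenating them (all-gather) restores the tensor.
shards = [full[..., r * shard:(r + 1) * shard] for r in range(tp_size)]
assert torch.equal(torch.cat(shards, dim=-1), full)
```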

tests/lora/test_olmoe_tp.py

Lines changed: 8 additions & 2 deletions
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 
+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest
 
@@ -111,8 +113,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
 
 
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
 @multi_gpu_test(num_gpus=2)
-def test_olmoe_lora_tp2(olmoe_lora_files):
+def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
@@ -122,14 +125,16 @@ def test_olmoe_lora_tp2(olmoe_lora_files):
         trust_remote_code=True,
         enable_chunked_prefill=True,
         tensor_parallel_size=2,
+        fully_sharded_loras=fully_sharded_loras,
     )
 
     generate_and_test(llm, olmoe_lora_files, lora_id=1)
     generate_and_test(llm, olmoe_lora_files, lora_id=2)
 
 
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
 @multi_gpu_test(num_gpus=4)
-def test_olmoe_lora_tp4(olmoe_lora_files):
+def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
@@ -139,6 +144,7 @@ def test_olmoe_lora_tp4(olmoe_lora_files):
         trust_remote_code=True,
         enable_chunked_prefill=True,
         tensor_parallel_size=4,
+        fully_sharded_loras=fully_sharded_loras,
     )
 
     generate_and_test(llm, olmoe_lora_files, lora_id=1)
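For reference, enabling the feature end to end looks roughly like the updated tests above. This is a hedged usage sketch: `fully_sharded_loras` and `tensor_parallel_size` appear in the diff, while the model path, `enable_lora`, and `max_loras` are assumptions about a typical LoRA-enabled setup rather than values taken from this commit; the server-side equivalent is the `--fully-sharded-loras` flag named in the commit title.

```python
import vllm

# Hypothetical MoE checkpoint; substitute the model your LoRA adapter targets.
llm = vllm.LLM(
    "allenai/OLMoE-1B-7B-0924",
    enable_lora=True,          # assumed: required for any LoRA serving
    max_loras=4,               # assumed capacity, not from this diff
    tensor_parallel_size=2,
    fully_sharded_loras=True,  # the option this commit extends to fused MoE
)
```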

vllm/lora/layers/fused_moe.py

Lines changed: 32 additions & 4 deletions
@@ -12,6 +12,7 @@
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from vllm.distributed.utils import divide
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -205,6 +206,7 @@ def wrapper(*args, **kwargs):
                 shrink_config, ## pass the shrink config
                 expand_config, ## pass the expand config
                 self.adapter_enabled,
+                fully_sharded=self.fully_sharded,
             )
 
             result = func(*args, **kwargs)
@@ -250,7 +252,10 @@ def wrapper(*args, **kwargs):
             sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
             intermediate_cache2 = moe_state_dict["intermediate_cache2"]
             intermediate_cache3 = args[0]
-            max_lora_rank = self.w1_lora_a_stacked.shape[-2]
+            max_lora_rank = self.w2_lora_a_stacked.shape[-2]
+
+            shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
+
             self.punica_wrapper.add_lora_fused_moe(
                 intermediate_cache3,
                 intermediate_cache2,
@@ -266,6 +271,8 @@ def wrapper(*args, **kwargs):
                 expand_config, ## pass the expand config
                 self.adapter_enabled,
                 True,
+                fully_sharded=self.fully_sharded,
+                offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
             )
 
             result = func(*args, **kwargs)
@@ -294,6 +301,7 @@ def create_lora_weights(
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+        self.fully_sharded = lora_config.fully_sharded_loras
 
         self.adapter_enabled = torch.tensor(
             [0] * (max_loras + 1), dtype=torch.int, device=self.device
@@ -303,7 +311,9 @@ def create_lora_weights(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                lora_config.max_lora_rank,
+                lora_config.max_lora_rank
+                if not self.fully_sharded
+                else divide(lora_config.max_lora_rank, self.tp_size),
                 self.base_layer.hidden_size,
             ),
             dtype=lora_config.lora_dtype,
@@ -334,7 +344,9 @@ def create_lora_weights(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                self.base_layer.hidden_size,
+                self.base_layer.hidden_size
+                if not self.fully_sharded
+                else divide(self.base_layer.hidden_size, self.tp_size),
                 lora_config.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
@@ -345,7 +357,9 @@ def create_lora_weights(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                lora_config.max_lora_rank,
+                lora_config.max_lora_rank
+                if not self.fully_sharded
+                else divide(lora_config.max_lora_rank, self.tp_size),
                 self.base_layer.hidden_size,
             ),
             dtype=lora_config.lora_dtype,
@@ -419,6 +433,20 @@ def set_lora(
             w3_lora_b = w3_lora_b[start_idx:end_idx, :]
             w2_lora_a = w2_lora_a[:, start_idx:end_idx]
 
+        if self.fully_sharded:
+            # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
+            # and W2 B along the hidden_size dim.
+            w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+            w13_start_idx = self.tp_rank * w13_shard_size
+            w13_end_idx = (self.tp_rank + 1) * w13_shard_size
+            w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
+            w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
+
+            w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
+            w2_start_idx = self.tp_rank * w2_shard_size
+            w2_end_idx = (self.tp_rank + 1) * w2_shard_size
+            w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
+
         self.w1_lora_a_stacked[
             index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
         ].copy_(w1_lora_a, non_blocking=True)
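A small, self-contained sketch of the shape bookkeeping that `create_lora_weights` performs above when `fully_sharded` is set: the rank dimension of the stacked W1/W3 LoRA A buffers and the hidden_size dimension of the stacked W2 LoRA B buffer are divided by the tensor-parallel size, and the W2 expand writes its output at an offset of `shard_size_w2 * tp_rank`. The concrete numbers are illustrative; vLLM's `divide` additionally asserts even divisibility, which plain integer division stands in for here.

```python
max_loras, local_num_experts = 4, 8
max_lora_rank, hidden_size, tp_size = 64, 2048, 2

rank_shard = max_lora_rank // tp_size    # per-rank rank dim for w1/w3 LoRA A -> 32
hidden_shard = hidden_size // tp_size    # per-rank output dim for w2 LoRA B  -> 1024

w1_lora_a_stacked_shape = (max_loras, local_num_experts, rank_shard, hidden_size)
w2_lora_b_stacked_shape = (max_loras, local_num_experts, hidden_shard, max_lora_rank)

# Offset at which each rank's expand kernel writes the row-parallel (w2) output
# before the all-reduce combines the shards.
for tp_rank in range(tp_size):
    offset = hidden_shard * tp_rank
    print(f"rank {tp_rank}: writes output[..., {offset}:{offset + hidden_shard}]")
```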
