
Commit 294c805

Early exit for MoE LoRA kernels (vllm-project#27131)

Authored by gnovack and jeejeelee.

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

1 parent 40b69e3 · commit 294c805

11 files changed, +123 −34 lines

csrc/moe/moe_lora_align_sum_kernels.cu

Lines changed: 16 additions & 11 deletions
@@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t* lora_ids) {
   const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
 
-  int lora_id = blockIdx.x;
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
+    return;
+  }
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
   }
 }
 
-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids) {
   const int topk_num = topk_ids.size(1);
 
   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
         expert_ids.data_ptr<int32_t>(), topk_num,
-        num_tokens_post_pad.data_ptr<int32_t>());
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
   });
 }
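In short, the alignment kernel now receives the slot-to-adapter mapping (lora_ids) and a per-adapter adapter_enabled flag, and each CUDA block returns immediately when its slot is empty or its adapter has no MoE LoRA weights, instead of computing alignment metadata for an unused adapter. A minimal host-side sketch of that guard, with names taken from the diff but the helper itself being illustrative only:

def block_does_work(block_idx: int, lora_ids: list[int], adapter_enabled: list[int]) -> bool:
    # Mirrors the early exit added at the top of moe_lora_align_sum_kernel:
    # a block indexed by blockIdx.x resolves its adapter id via lora_ids and
    # skips all shared-memory work when the slot is empty (-1) or disabled.
    lora_id = lora_ids[block_idx]
    return lora_id != -1 and adapter_enabled[lora_id] != 0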

csrc/moe/moe_ops.h

Lines changed: 7 additions & 8 deletions
@@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   torch::Tensor expert_ids,
                                   torch::Tensor num_tokens_post_pad);
 
-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,

csrc/moe/torch_bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " int max_num_m_blocks, "
       " Tensor !sorted_token_ids,"
       " Tensor !experts_ids,"
-      " Tensor !num_tokens_post_pad) -> () ");
+      " Tensor !num_tokens_post_pad,"
+      " Tensor !adapter_enabled,"
+      " Tensor !lora_ids) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
 
 #ifndef USE_ROCM

tests/lora/test_fused_moe_lora_kernel.py

Lines changed: 6 additions & 0 deletions
@@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
     )
     expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
     num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
+    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
 
     # call kernel
     ops.moe_lora_align_block_size(
@@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        adapter_enabled,
+        lora_ids,
     )
 
     config = {
@@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
         num_tokens_post_padded,
         max_lora_rank,
         top_k_num,
+        lora_ids,
+        adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],

tests/lora/test_moe_lora_align_sum.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
         (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
+    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
 
     # call kernel
     ops.moe_lora_align_block_size(
@@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )
 
     # verify values

tests/lora/test_olmoe_tp.py

Lines changed: 43 additions & 7 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
 import vllm
 from vllm.lora.request import LoRARequest
 
@@ -28,8 +29,17 @@
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]
 
+EXPECTED_BASE_MODEL_OUTPUT = [
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID",  # noqa: E501
+    "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1",  # noqa: E501
+]
+
 
-def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
+def generate_and_test(
+    llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
+) -> None:
     prompts = [
         PROMPT_TEMPLATE.format(context="How many candidates are there?"),
         PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
             context="Return the poll resource associated with the most candidates."
         ),
     ]
+
+    lora_request = None
+    if isinstance(lora_id, int):
+        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
+    elif isinstance(lora_id, list):
+        lora_request = [
+            LoRARequest(str(i), i, lora_path) if i is not None else None
+            for i in lora_id
+        ]
+
     sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
-    )
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
     # Print the outputs.
     generated_texts: list[str] = []
     for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
     for i in range(len(EXPECTED_LORA_OUTPUT)):
-        assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
+        req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
+        expected_output = (
+            EXPECTED_LORA_OUTPUT[i]
+            if req_lora_id is not None
+            else EXPECTED_BASE_MODEL_OUTPUT[i]
+        )
+        assert generated_texts[i].startswith(expected_output)
 
 
 def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,20 @@ def test_olmoe_lora(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=2)
 
 
+def test_olmoe_lora_mixed(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
+
+
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files):
     llm = vllm.LLM(
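The new test exercises a batch that mixes LoRA and base-model requests. A condensed sketch of that calling pattern, assuming llm, prompts, sampling_params, and lora_path are already set up as in the test above:

from vllm.lora.request import LoRARequest

lora_ids = [1, None, 3, None]
lora_requests = [
    LoRARequest(str(i), i, lora_path) if i is not None else None
    for i in lora_ids
]
# Prompts paired with a LoRARequest go through the adapter; those paired with
# None fall back to the base model and are checked against EXPECTED_BASE_MODEL_OUTPUT.
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)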

vllm/_custom_ops.py

Lines changed: 4 additions & 0 deletions
@@ -1823,6 +1823,8 @@ def moe_lora_align_block_size(
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    adapter_enabled: torch.Tensor,
+    lora_ids: torch.Tensor,
 ) -> None:
     torch.ops._moe_C.moe_lora_align_block_size(
         topk_ids,
@@ -1835,6 +1837,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )
 
 

vllm/lora/layers/fused_moe.py

Lines changed: 10 additions & 1 deletion
@@ -111,6 +111,7 @@ def wrapper(*args, **kwargs):
                 config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
                 max_loras,
+                self.adapter_enabled,
                 expert_map,
             )
 
@@ -138,6 +139,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
             )
 
             result = func(*args, **kwargs)
@@ -196,6 +198,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
                 True,
             )
 
@@ -227,6 +230,10 @@ def create_lora_weights(
     ) -> None:
         """Initializes lora matrices."""
 
+        self.adapter_enabled = torch.tensor(
+            [0] * (max_loras + 1), dtype=torch.int, device=self.device
+        )
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -313,6 +320,7 @@ def reset_lora(self, index: int):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+        self.adapter_enabled[index] = 0
 
     def set_lora(
         self,
@@ -322,8 +330,9 @@ def set_lora(
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
-        self.reset_lora(index)
         """Overwrites lora tensors at index."""
+        self.reset_lora(index)
+        self.adapter_enabled[index] = 1
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
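The layer keeps adapter_enabled as one int flag per LoRA slot (with one extra trailing slot, presumably reserved for the no-LoRA case), zeroing it in reset_lora and setting it in set_lora so the kernels can skip disabled adapters. A stripped-down sketch of that bookkeeping, using a hypothetical class rather than the actual vLLM layer:

import torch

class AdapterFlags:
    def __init__(self, max_loras: int, device: torch.device):
        # One flag per slot, plus a trailing slot; 0 = disabled, 1 = loaded.
        self.adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int, device=device)

    def reset_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 0

    def set_lora(self, index: int) -> None:
        # Called after reset, once the adapter's MoE LoRA weights are written.
        self.adapter_enabled[index] = 1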

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 20 additions & 5 deletions
@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@ def _fused_moe_lora_kernel(
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return
     max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
 
@@ -100,12 +107,12 @@ def _fused_moe_lora_kernel(
     pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
     pid_n = (pid_m_n % num_pid_in_group) // group_size_m
 
-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
 
     # get the expert_id to process curr shard
-    ind = lora_idx * stride_el + pid_m
+    ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
@@ -119,7 +126,7 @@ def _fused_moe_lora_kernel(
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
 
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    token_ind = stride_tl * lora_idx + offs_token_id
+    token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
         sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
@@ -132,7 +139,7 @@ def _fused_moe_lora_kernel(
 
     b_ptrs = (
         cur_b_ptr
-        + lora_idx * stride_bl
+        + lora_id * stride_bl
         + expert_id * stride_be
         + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
@@ -184,6 +191,8 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,  # (max_loras, )
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
@@ -234,7 +243,7 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
 
-    lora_intermediate_cache1 = torch.empty(
+    lora_intermediate_cache1 = torch.zeros(
        (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
        dtype=output.dtype,
        device=device,
@@ -272,6 +281,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -319,6 +330,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(0),
         a_intermediate_cache1.stride(1),
         w1_lora_b_stacked.stride(0),
@@ -352,6 +365,8 @@ def _fused_moe_lora_fake(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
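The Triton kernel now resolves the grid's lora_idx to a physical lora_id before doing any work, and all downstream indexing (num_tokens_post_padded, expert_ids, sorted_token_ids, and the LoRA weight pointers) uses the resolved id. A pure-Python sketch of that resolution step, with hypothetical names standing in for the tl.load calls:

def resolve_lora_id(lora_idx: int, lora_ids, adapter_enabled):
    # Equivalent of the guard added after tl.program_id(axis=2): a program
    # whose slot is empty or whose adapter is disabled exits before the GEMM.
    lora_id = int(lora_ids[lora_idx])
    if lora_id == -1 or int(adapter_enabled[lora_id]) == 0:
        return None
    return lora_id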
