@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
5454 EM ,
5555 num_valid_tokens ,
5656 num_experts ,
57+ lora_ids ,
58+ adapter_enabled ,
5759 # The stride variables represent how much to increase the ptr by when
5860 # moving by 1 element in a particular dimension. E.g. `stride_am` is
5961 # how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@ def _fused_moe_lora_kernel(
8486 pid = tl .program_id (axis = 0 )
8587 slice_id = tl .program_id (axis = 1 )
8688 lora_idx = tl .program_id (axis = 2 )
89+ lora_id = tl .load (lora_ids + lora_idx )
90+ moe_enabled = tl .load (adapter_enabled + lora_id )
91+ if lora_id == - 1 or moe_enabled == 0 :
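92+ # Early exit when this slot has no LoRA (lora_id == -1) or its adapter_enabled entry is 0. NOTE(review): adapter_enabled is loaded above before the -1 check, so lora_id == -1 reads index -1 -- consider checking lora_id first.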
93+ return
8794 max_loras = tl .num_programs (axis = 2 )
8895 grid_k = tl .cdiv (K , BLOCK_SIZE_K * SPLIT_K )
8996
@@ -100,12 +107,12 @@ def _fused_moe_lora_kernel(
100107 pid_m = first_pid_m + ((pid_m_n % num_pid_in_group ) % group_size_m )
101108 pid_n = (pid_m_n % num_pid_in_group ) // group_size_m
102109
103- num_tokens_post_padded = tl .load (num_tokens_post_padded_ptr + lora_idx )
110+ num_tokens_post_padded = tl .load (num_tokens_post_padded_ptr + lora_id )
104111 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded :
105112 return
106113
107114 # get the expert_id to process curr shard
108- ind = lora_idx * stride_el + pid_m
115+ ind = lora_id * stride_el + pid_m
109116 expert_id = tl .load (expert_ids_ptr + ind , ind < max_loras * stride_el , - 1 )
110117 if expert_id == - 1 :
111118 return
@@ -119,7 +126,7 @@ def _fused_moe_lora_kernel(
119126 offs_k = pid_sk * BLOCK_SIZE_K + tl .arange (0 , BLOCK_SIZE_K )
120127
121128 offs_token_id = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M ).to (tl .int64 )
122- token_ind = stride_tl * lora_idx + offs_token_id
129+ token_ind = stride_tl * lora_id + offs_token_id
123130 offs_token = tl .load (
124131 sorted_token_ids_ptr + token_ind , token_ind < max_loras * stride_tl , 0
125132 )
@@ -132,7 +139,7 @@ def _fused_moe_lora_kernel(
132139
133140 b_ptrs = (
134141 cur_b_ptr
135- + lora_idx * stride_bl
142+ + lora_id * stride_bl
136143 + expert_id * stride_be
137144 + offs_k [:, None ] * stride_bk
138145 + offs_bn [None , :] * stride_bn
@@ -184,6 +191,8 @@ def _fused_moe_lora(
184191 num_tokens_post_padded : torch .Tensor , # (max_loras, )
185192 max_lora_rank : int ,
186193 top_k_num : int ,
194+ lora_ids : torch .Tensor ,
195+ adapter_enabled : torch .Tensor ,
187196 block_size_m : int ,
188197 block_size_n : int ,
189198 block_size_k : int ,
@@ -234,7 +243,7 @@ def _fused_moe_lora(
234243 num_tokens = M * top_k_num
235244 w1_output_dim_size = w1_lora_b_stacked .shape [2 ]
236245
237- lora_intermediate_cache1 = torch .empty (
246+ lora_intermediate_cache1 = torch .zeros (
238247 (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size )),
239248 dtype = output .dtype ,
240249 device = device ,
@@ -272,6 +281,8 @@ def _fused_moe_lora(
272281 EM ,
273282 num_tokens ,
274283 num_experts ,
284+ lora_ids ,
285+ adapter_enabled ,
275286 qcurr_hidden_states .stride (0 ),
276287 qcurr_hidden_states .stride (1 ),
277288 w1_lora_a_stacked .stride (0 ),
@@ -319,6 +330,8 @@ def _fused_moe_lora(
319330 EM ,
320331 num_tokens ,
321332 num_experts ,
333+ lora_ids ,
334+ adapter_enabled ,
322335 a_intermediate_cache1 .stride (0 ),
323336 a_intermediate_cache1 .stride (1 ),
324337 w1_lora_b_stacked .stride (0 ),
@@ -352,6 +365,8 @@ def _fused_moe_lora_fake(
352365 num_tokens_post_padded : torch .Tensor ,
353366 max_lora_rank : int ,
354367 top_k_num : int ,
368+ lora_ids : torch .Tensor ,
369+ adapter_enabled : torch .Tensor ,
355370 block_size_m : int ,
356371 block_size_n : int ,
357372 block_size_k : int ,
0 commit comments