@@ -4,11 +4,18 @@
 from typing import Optional
 
 import torch
+from typing_extensions import override
 
 import vllm._custom_ops as ops
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
+    marlin_make_workspace_new, marlin_moe_intermediate_size,
+    maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import direct_register_custom_op
 
@@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      bias2: Optional[torch.Tensor],
                      w1_scale: torch.Tensor,
                      w2_scale: torch.Tensor,
-                     gating_output: torch.Tensor,
+                     gating_output: Optional[torch.Tensor],
                      topk_weights: torch.Tensor,
                      topk_ids: torch.Tensor,
                      quant_type_id: int,
@@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      w1_zeros: Optional[torch.Tensor] = None,
                      w2_zeros: Optional[torch.Tensor] = None,
                      workspace: Optional[torch.Tensor] = None,
+                     intermediate_cache13: Optional[torch.Tensor] = None,
+                     intermediate_cache2: Optional[torch.Tensor] = None,
                      is_k_full: bool = True,
+                     output: Optional[torch.Tensor] = None,
                      inplace: bool = False) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     - w2 (torch.Tensor): The second set of expert weights.
     - w1_scale (torch.Tensor): Scale to be used for w1.
     - w2_scale (torch.Tensor): Scale to be used for w2.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
+    - gating_output (Optional[torch.Tensor]): The output of the gating
+        operation (before softmax).
     - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
     - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
     - sort_indices1 (Optional[torch.Tensor]): The first act_order input
@@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     num_bits = 4 if quant_type in bit4_scalar_types else 8
 
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[
-        0], "Number of tokens mismatch"
+    if gating_output is not None:
+        assert hidden_states.shape[0] == gating_output.shape[
+            0], "Number of tokens mismatch"
     assert hidden_states.shape[
         1] == w1.shape[1] * 16, "Hidden size mismatch w1"
     assert hidden_states.shape[1] == w2.shape[2] // (
@@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
 
     M, K = hidden_states.shape
     E = w1.shape[0]
-    N = w2.shape[1] * 16
+    N = marlin_moe_intermediate_size(w1, w2)
     topk = topk_ids.shape[1]
 
     # M block size selection logic
@@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if workspace is None:
         workspace = marlin_make_workspace_new(hidden_states.device, 4)
 
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache13 = torch.empty(
-        (M * topk_ids.shape[1] * max(2 * N, K), ),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
-    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
-    intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
-    intermediate_cache3 = intermediate_cache3.view(-1, K)
+    if intermediate_cache2 is None:
+        intermediate_cache2 = torch.empty(
+            (M * topk, N),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    if intermediate_cache13 is None:
+        intermediate_cache13 = torch.empty(
+            (M * topk * max(2 * N, K), ),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    intermediate_cache1 = _resize_cache(intermediate_cache13,
+                                        (M * topk, 2 * N))
+    intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
+    intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
 
     maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
     use_atomic_add = hidden_states.dtype == torch.half or \
@@ -200,18 +215,17 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         use_fp32_reduce=True,
         is_zp_float=False).view(-1, topk, K)
 
-    output = hidden_states if inplace else torch.empty_like(hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1,
-                     out=output)
+    if output is None:
+        output = hidden_states if inplace else torch.empty_like(hidden_states)
+    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
 
 
 def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
                           w1_scale: torch.Tensor,
                           w2_scale: torch.Tensor,
-                          gating_output: torch.Tensor,
+                          gating_output: Optional[torch.Tensor],
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           quant_type_id: int,
@@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1_zeros: Optional[torch.Tensor] = None,
                           w2_zeros: Optional[torch.Tensor] = None,
                           workspace: Optional[torch.Tensor] = None,
+                          intermediate_cache13: Optional[torch.Tensor] = None,
+                          intermediate_cache2: Optional[torch.Tensor] = None,
                           is_k_full: bool = True,
+                          output: Optional[torch.Tensor] = None,
                           inplace: bool = False) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -237,3 +254,124 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
     op_func=fused_marlin_moe,
     fake_impl=fused_marlin_moe_fake,
 )
+
+
+class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun): Enable activation quantization.
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        super().__init__(quant_config)
+
+    @override
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        assert w1.dim() == 3 and w2.dim() == 3
+
+        E = w1.size(0)
+        K = a1.size(-1)
+        N = marlin_moe_intermediate_size(w1, w2)
+
+        if a1.dim() == 2:
+            # Make sure we are using the correct a1 (pre-permute).
+            assert topk_ids.size(0) == a1.size(0), \
+                f"{topk_ids.size(0)} != {a1.size(0)}"
+            M = a1.size(0)
+        else:
+            assert a1.dim() == 3
+            assert a1.size(0) == E, f"{a1.size(0)} != {E}"
+            M = a1.size(1)  # This is max_num_tokens
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def workspace_shapes(
+            self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+            topk: int, global_num_experts: int, local_num_experts: int,
+            expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # The Modular Kernel provisions the output buffer from workspace1.
+        # However, in the fused_marlin_moe() function the final torch.sum()
+        # is essentially
+        #   `torch.sum(workspace1, dim=1, out=output)`
+        # Having overlapping input and output tensors for torch.sum is error
+        # prone and depends on how torch.sum is implemented. For this reason
+        # we swap the workspaces so that the output buffer is provisioned
+        # from workspace2.
+
+        # Workspace/IntermediateCache allocation matching fused_marlin_moe():
+        # workspace1 = (M * topk * max(2 * N, K),)
+        # workspace2 = (M * topk, N)
+
+        # Workspace/IntermediateCache allocation accounting for output buffer
+        # provisioning:
+        workspace1 = (M * topk, max(N, K))
+        workspace2 = (M * topk * max(2 * N, K), )
+        output = (M, K)
+
+        return (workspace1, workspace2, output, a.dtype)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        assert self.w1_scale is not None
+        assert self.w2_scale is not None
+        return fused_marlin_moe(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            bias1=self.w1_bias,
+            bias2=self.w2_bias,
+            w1_scale=self.w1_scale,
+            w2_scale=self.w2_scale,
+            gating_output=None,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            activation=activation,
+            expert_map=expert_map,
+            output=output,
+            # Workspaces are swapped in workspace_shapes() so that the output
+            # buffer can be provisioned correctly; see workspace_shapes().
+            intermediate_cache13=workspace2,
+            intermediate_cache2=workspace13)
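
Below the diff, a minimal standalone sketch (not part of the change) of the buffer-aliasing concern that motivates the workspace swap in MarlinExperts.workspace_shapes(): intermediate_cache1 and intermediate_cache3 are carved out of the same flat buffer, so the tensor reduced by the final torch.sum() must not share storage with the sum's output. The shapes and values here are illustrative assumptions only.

import torch

M, topk, N, K = 4, 2, 8, 16

# Flat buffer reused for both intermediate views, mirroring
# intermediate_cache13 above: cache1 is (M * topk, 2 * N) and
# cache3 is (M * topk, K); both alias the same storage.
flat = torch.empty(M * topk * max(2 * N, K))
cache1 = flat[:M * topk * 2 * N].view(M * topk, 2 * N)
cache3 = flat[:M * topk * K].view(M * topk, K)

# Final reduction as in fused_marlin_moe(): sum over the topk dimension
# into a separately allocated output tensor, so the reduction never
# writes into the buffer it is reading from.
cache3.normal_()
out = torch.empty(M, K)
torch.sum(cache3.view(-1, topk, K), dim=1, out=out)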