Commit eb00ffd

mgoin authored and devpatelio committed
[Bugfix] Disable shared expert overlap if Marlin MoE is used (vllm-project#28410)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 35beebc commit eb00ffd

File tree

6 files changed: +13 −5 lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 0 deletions
@@ -678,6 +678,10 @@ def use_flashinfer_cutlass_kernels(self):
             and self.moe_config.use_flashinfer_cutlass_kernels
         )
 
+    @property
+    def use_marlin_kernels(self):
+        return getattr(self.quant_method, "use_marlin", False)
+
     @property
     def use_dp_chunking(self) -> bool:
         return (

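The new property simply defers to the layer's quantization method: Marlin-backed methods set use_marlin = True in their constructors (see the changes below), while every other method never defines the attribute and falls through to the getattr default of False. A minimal sketch of that resolution, using made-up stand-in classes rather than vLLM's real quantization methods:

class _NonMarlinMethod:
    # Never sets use_marlin, so the property reports False.
    pass

class _MarlinMethod:
    def __init__(self):
        # Opts in, as the quantization methods changed below now do.
        self.use_marlin = True

for method in (_NonMarlinMethod(), _MarlinMethod()):
    # Mirrors FusedMoE.use_marlin_kernels: a missing attribute defaults to False.
    print(type(method).__name__, getattr(method, "use_marlin", False))
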
vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 5 additions & 5 deletions
@@ -28,17 +28,17 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
-        # Disable shared expert overlap if we are using eplb, because of
-        # correctness issues, or if using flashinfer with DP, since there
-        # is nothing to be gained in this case. Disabling the overlap
-        # optimization also prevents the shared experts from being hidden
-        # from torch.compile.
+        # Disable shared expert overlap if:
+        # - we are using eplb, because of correctness issues
+        # - we are using flashinfer with DP, since there is nothing to gain
+        # - we are using Marlin kernels
         self.use_overlapped = (
             use_overlapped
             and not (
                 # TODO(wentao): find the root cause and remove this condition
                 self.enable_eplb
                 or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+                or self.use_marlin_kernels
             )
             and self._shared_experts is not None
         )

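Taken together, the overlap decision after this change reduces to roughly the predicate below. This is a simplified sketch with the relevant flags pulled out as plain parameters, not the actual SharedFusedMoE code:

def should_overlap_shared_experts(
    use_overlapped: bool,
    enable_eplb: bool,
    use_flashinfer_cutlass_kernels: bool,
    dp_size: int,
    use_marlin_kernels: bool,
    has_shared_experts: bool,
) -> bool:
    # Any one of these conditions forces the shared experts onto the
    # regular (non-overlapped) execution path.
    blocked = (
        enable_eplb
        or (use_flashinfer_cutlass_kernels and dp_size > 1)
        or use_marlin_kernels
    )
    return use_overlapped and not blocked and has_shared_experts

# With Marlin MoE kernels active, overlap stays off even when requested.
assert not should_overlap_shared_experts(
    use_overlapped=True,
    enable_eplb=False,
    use_flashinfer_cutlass_kernels=False,
    dp_size=1,
    use_marlin_kernels=True,
    has_shared_experts=True,
)
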
vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 1 addition & 0 deletions
@@ -424,6 +424,7 @@ def __init__(
         if self.quant_config.weight_bits != 4:
             raise ValueError("AWQMoEMethod only supports 4bit now.")
         self.quant_type = scalar_types.uint4
+        self.use_marlin = True
 
     def create_weights(
         self,

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 1 addition & 0 deletions
@@ -1342,6 +1342,7 @@ def __init__(
                 f"{WNA16_SUPPORTED_BITS}",
             )
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
+        self.use_marlin = True
 
     def create_weights(
         self,

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 1 addition & 0 deletions
@@ -482,6 +482,7 @@ def __init__(
             self.quant_type = scalar_types.uint8b128
         else:
             raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
+        self.use_marlin = True
 
     def create_weights(
         self,

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 1 addition & 0 deletions
@@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
         self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
+        self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
         self.max_capture_size = (
             get_current_vllm_config().compilation_config.max_cudagraph_capture_size
         )
