From d6eebca2adf60de70bbd3d21f978b83ec0b24740 Mon Sep 17 00:00:00 2001
From: chenbaixuan
Date: Mon, 1 Dec 2025 10:21:43 +0800
Subject: [PATCH] eplb: enable MTP layers

Signed-off-by: chenbaixuan
---
 vllm_ascend/eplb/adaptor/vllm_adaptor.py | 58 +++++++++++++--
 vllm_ascend/eplb/eplb_updator.py         |  8 +-
 vllm_ascend/eplb/utils.py                | 94 +++++++++++++++---------
 vllm_ascend/worker/model_runner_v1.py    | 26 ++++++-
 4 files changed, 142 insertions(+), 44 deletions(-)

diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index 726763013f4..7dec49638ed 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -21,6 +21,7 @@
 import torch
 import torch.distributed as dist
 from vllm.logger import logger
+from vllm.config import get_current_vllm_config
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
@@ -28,19 +29,21 @@ class VllmEplbAdaptor(EplbAdaptor):
 
-    def __init__(self, model, **args):
+    def __init__(self, model, mtp_instance=None, num_mtp_layers=0, **args):
         super().__init__(**args)
         self.model = model
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
+        self.mtp_instance = mtp_instance
+        self.num_mtp_layers = num_mtp_layers
         if self.model.config.model_type == "qwen3_moe":
             self.num_dense_layers = 0
             self.global_expert_num = self.model.config.num_experts
         else:
             self.num_dense_layers = self.model.config.first_k_dense_replace
             self.global_expert_num = self.model.config.n_routed_experts
-        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers  # MTP not included
 
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert
@@ -53,6 +56,19 @@ def __init__(self, model, **args):
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
 
+        # TODO: init self.mtp_expert_weight_names depending on different model types; only deepseek v3 w8a8 and qwen3-moe are supported here
+        if self.mtp_instance is not None:
+            if any("w13_weight_offset" in name for name, _ in self.mtp_instance.named_parameters()):
+                self.mtp_expert_weight_names = [
+                    "w13_weight", "w2_weight", "w13_weight_scale",
+                    "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+                ]
+            else:
+                self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+        else:
+            self.mtp_expert_weight_names = []
+
+
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
         self.expert_map_per_layer_cpu = dict(
@@ -61,6 +77,12 @@ def __init__(self, model, **args):
             self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_expert_map(self.num_dense_layers + layer_idx)
 
+        # Currently, MTP supports only one layer.
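+        # MTP layers are appended after the main model's MoE layers, so MTP
+        # layer k is addressed at absolute index
+        # num_dense_layers + num_moe_layers + k in the per-layer dicts below.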
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved
         num_buffer_tensor = torch.where(
             self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
@@ -76,6 +98,11 @@ def __init__(self, model, **args):
         for layer_idx in range(self.num_moe_layers):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
+
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_log2phy_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
 
         self.all_topk_ids = []
@@ -103,13 +130,30 @@ def init_expert_param_per_layer(self):
                             name].data[local_expert_id]
                         for name in self.expert_weight_names
                     ])
+
+        if self.mtp_instance is not None:
+            mtp_param_dict = dict(self.mtp_instance.named_parameters())
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                for mtp_layer_idx in range(self.num_mtp_layers):
+                    self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
+                        mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
+                                       ".mtp_block.mlp.experts." +
+                                       name].data[local_expert_id]
+                        for name in self.mtp_expert_weight_names
+                    ])
 
     def get_rank_expert_workload(self) -> torch.Tensor:
         self.moe_load = self.model.get_all_moe_loads()
+        if self.mtp_instance is not None:
+            self.moe_load = torch.cat([self.moe_load, self.mtp_instance.model.get_all_moe_loads().to(device=self.moe_load.device)], dim=0)
         return self.moe_load
 
     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
+        if self.mtp_instance is not None:
+            expert_map = torch.cat([expert_map, self.mtp_instance.model.get_all_expert_map().to(device=expert_map.device)], dim=0)
 
         if dist.is_initialized():
             world_size = dist.get_world_size()
@@ -261,9 +305,11 @@ def determine_expert_map_all(self):
         local_num_experts = self.global_expert_num // self.world_size
 
         expert_map_all = torch.full(
-            (self.num_moe_layers, self.world_size, self.global_expert_num),
-            -1,
-            dtype=torch.int32)
+            (self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers),
+             self.world_size,
+             self.global_expert_num),
+            -1,
+            dtype=torch.int32)
 
         for r in range(self.world_size):
             if r < self.world_size - 1:
@@ -284,6 +330,6 @@ def determine_expert_map_all(self):
                 local_ids = torch.arange(local_count, dtype=torch.int32)
 
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
-                self.num_moe_layers, -1)
+                self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers), -1)
 
         return expert_map_all
diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index f2a5b695060..3ae71ff6869 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -35,9 +35,11 @@ def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
         self.eplb_process = eplb_process
         self.shared_dict = self.eplb_process.shared_dict
 
-    def set_adaptor(self, adaptor):
+    def set_adaptor(self, adaptor, num_mtp_layers):
         self.adaptor = adaptor
-        self.num_moe_layers = self.adaptor.num_moe_layers
+        self.num_moe_layers = (
+            self.adaptor.num_moe_layers if self.adaptor.mtp_instance is None else self.adaptor.num_moe_layers + num_mtp_layers
+        )
         self.global_expert_num = self.adaptor.global_expert_num
 
     def init_eplb(self, expert_map_path, process):
@@ -84,6 +86,8 @@ def update_iteration(self):
                     self.expert_map_record_path)
             self.adaptor.model.clear_all_moe_loads()
+            if self.adaptor.mtp_instance is not None:
+                self.adaptor.mtp_instance.model.clear_all_moe_loads()
 
             if not self.gate_eplb:
                 self.cur_iterations = 0
diff --git a/vllm_ascend/eplb/utils.py b/vllm_ascend/eplb/utils.py
index 61e5735e4dc..cb9bee00a72 100644
--- a/vllm_ascend/eplb/utils.py
+++ b/vllm_ascend/eplb/utils.py
@@ -19,45 +19,72 @@
 
 import torch
 
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor
+
 
 def get_expert_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()
 
 
 def get_log2phy_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
-
-
-def get_all_expert_map(self, num_moe_layers):
-    all_loads = []
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(num_moe_layers):
-        load_tensor = self.get_expert_map(
-            layer_id + num_dense_layers)  # (num_experts_per_layer,)
-        all_loads.append(load_tensor)
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()
+
+
+def get_all_expert_map(self, num_moe_layers=None):
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        all_loads = []
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_expert_map(
+                layer_id + num_dense_layers)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+    else:
+        all_loads = []
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            load_tensor = self.get_expert_map(layer_id)
+            all_loads.append(load_tensor)
 
     return torch.stack(all_loads, dim=0)
 
 
 def get_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    all_moe_loads = torch.stack(
-        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
-         for layer_id in range(self.num_moe_layers)],
-        dim=0
-    )
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
+             for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+    else:
+        all_moe_loads = torch.stack(
+            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
+             for idx in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers)],
+            dim=0
+        )
     return all_moe_loads
 
 
 def clear_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(self.num_moe_layers):
-        self.model.layers[layer_id +
-                          num_dense_layers].mlp.experts.clear_moe_load()
-
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id +
+                              num_dense_layers].mlp.experts.clear_moe_load()
+    else:
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()
+
+
 def model_register(model, model_config):
     model.get_expert_map = types.MethodType(get_expert_map, model)
@@ -66,12 +93,13 @@ def model_register(model, model_config):
     model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
     model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)
 
-    config = model_config.hf_config
+    if not isinstance(model, DeepSeekMultiTokenPredictor):
+        config = model_config.hf_config
 
-    if config.model_type == "qwen3_moe":
-        model.num_moe_layers = config.num_hidden_layers
-    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
-        model.num_dense_layers = config.first_k_dense_replace
-        model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
-    else:
-        raise NotImplementedError("EPLB is not supported.")
+        if config.model_type == "qwen3_moe":
+            model.num_moe_layers = config.num_hidden_layers
+        elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
+            model.num_dense_layers = config.first_k_dense_replace
+            model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
+        else:
+            raise NotImplementedError("EPLB is not supported.")
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 741f0d2e7f6..4c11a08f9c6 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -62,6 +62,7 @@
 from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                    supports_mrope,
                                                    supports_transcription)
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -278,6 +279,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                                   'decode_max_num_seqs', 0)
         self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
                                 decode_max_num_seqs)
+        self.mtp_instance = None
         if self.pcp_size > 1:
             self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
@@ -3007,6 +3009,8 @@ def dummy_compute_logits(hidden_states):
                     hidden_states[dummy_indices])
             if self.in_profile_run and self.dynamic_eplb:
                 self.model.clear_all_moe_loads()
+                if self.mtp_instance is not None:
+                    self.mtp_instance.model.clear_all_moe_loads()
             if not self.in_profile_run and self.dynamic_eplb:
                 self.eplb_updator.take_update_info_from_eplb_process()
                 self.eplb_updator.forward_end()
@@ -3129,9 +3133,14 @@ def _dummy_pooler_run(
 
     def eplb_warmup(self):
        if self.dynamic_eplb and not self.is_eplb_warmuped:
             self.is_eplb_warmuped = True
-            self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
+            num_mtp_layers = self.mtp_instance.model.num_mtp_layers if self.mtp_instance is not None else 0
+            self.eplb_adaptor = VllmEplbAdaptor(
+                model=self.model,
+                mtp_instance=self.mtp_instance,
+                num_mtp_layers=num_mtp_layers
+            )
             self.eplb_loader.set_adator(self.eplb_adaptor)
-            self.eplb_updator.set_adaptor(self.eplb_adaptor)
+            self.eplb_updator.set_adaptor(self.eplb_adaptor, num_mtp_layers)
             self.eplb_updator.warm_up_eplb()
 
     def load_model(self) -> None:
@@ -3140,7 +3149,7 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
             if self.dynamic_eplb:
-                model_register(self.model, self.model_config)
+                model_register(self.model, self.model_config)
             if is_310p():
                 from vllm.model_executor.layers.linear import (
                     MergedColumnParallelLinear, QKVParallelLinear,
@@ -3154,6 +3163,17 @@ def load_model(self) -> None:
             if self.drafter:
                 logger.info("Loading drafter model...")
                 self.drafter.load_model(self.model)
+                if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+                    assert isinstance(self.drafter, MtpProposer), \
+                        f"wrong drafter type: {type(self.drafter)}"
+                    assert isinstance(self.drafter.model, (DeepSeekMTP, ACLGraphWrapper)), \
+                        f"wrong drafter model type: {type(self.drafter.model)}; only DeepSeekMTP and ACLGraphWrapper are supported"
+                    if isinstance(self.drafter.model, DeepSeekMTP):
+                        self.mtp_instance = self.drafter.model
+                    elif isinstance(self.drafter.model, ACLGraphWrapper):
+                        self.mtp_instance = self.drafter.model.unwrap()
+                    model_register(self.mtp_instance.model, self.vllm_config)
+
                 if self.drafter.name == SpecDcodeType.EAGLE3:
                     self.model.set_aux_hidden_state_layers(
                         self.model.get_eagle3_aux_hidden_state_layers())
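
Reviewer note (illustration only, not part of the patch): the sketch below is a minimal, self-contained Python check of the layer-indexing convention this change relies on. The sizes are hypothetical stand-ins for first_k_dense_replace, num_hidden_layers - first_k_dense_replace, num_nextn_predict_layers, and n_routed_experts; only the concatenation order matters. MTP loads are appended after the main model's MoE loads, so MTP layer k lives at absolute index num_dense_layers + num_moe_layers + k.

    import torch

    # Hypothetical sizes for a DeepSeek-V3-style config.
    num_dense_layers = 3
    num_moe_layers = 58
    num_mtp_layers = 1   # MTP currently supports a single layer
    num_experts = 256

    # Per-module loads, as each module's get_all_moe_loads() would return them.
    main_loads = torch.randint(0, 100, (num_moe_layers, num_experts))
    mtp_loads = torch.randint(0, 100, (num_mtp_layers, num_experts))

    # get_rank_expert_workload() concatenates MTP rows after the main rows.
    combined = torch.cat([main_loads, mtp_loads], dim=0)
    assert combined.shape == (num_moe_layers + num_mtp_layers, num_experts)

    # Row (num_moe_layers + k) of the combined tensor corresponds to absolute
    # layer index (num_dense_layers + num_moe_layers + k), i.e. the key used
    # in expert_map_per_layer and log2phy_map_per_layer.
    for k in range(num_mtp_layers):
        assert torch.equal(combined[num_moe_layers + k], mtp_loads[k])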