58 changes: 52 additions & 6 deletions vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -21,26 +21,29 @@
import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm.config import get_current_vllm_config

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor


class VllmEplbAdaptor(EplbAdaptor):

def __init__(self, model, **args):
def __init__(self, model, mtp_instance=None, num_mtp_layers=0, **args):
super().__init__(**args)
self.model = model
self.rank_id = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_dict = dict(self.model.named_parameters())
self.mtp_instance = mtp_instance
self.num_mtp_layers = num_mtp_layers
if self.model.config.model_type == "qwen3_moe":
self.num_dense_layers = 0
self.global_expert_num = self.model.config.num_experts
else:
self.num_dense_layers = self.model.config.first_k_dense_replace
self.global_expert_num = self.model.config.n_routed_experts
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers # MTP not included
self.init_redundancy_expert = get_ascend_config(
).init_redundancy_expert

@@ -53,6 +56,19 @@ def __init__(self, model, **args):
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]

# TODO: init self.mtp_expert_weight_names depending on different model types; only deepseek v3 w8a8 and qwen3-moe are supported here
if self.mtp_instance is not None:
if any("w13_weight_offset" in name for name, _ in self.mtp_instance.named_parameters()):
self.mtp_expert_weight_names = [
"w13_weight", "w2_weight", "w13_weight_scale",
"w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
]
else:
self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
else:
self.mtp_expert_weight_names = []


self.expert_map_per_layer = dict(
) # reference to expert map on device for expert map update
self.expert_map_per_layer_cpu = dict(
@@ -61,6 +77,12 @@
self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
self.model.get_expert_map(self.num_dense_layers + layer_idx)

# Currently, MTP only supports one layer.
if self.mtp_instance is not None:
for mtp_layer_idx in range(self.num_mtp_layers):
self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)

# TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
num_buffer_tensor = torch.where(
self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
@@ -76,6 +98,11 @@ def __init__(self, model, **args):
for layer_idx in range(self.num_moe_layers):
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
self.model.get_log2phy_map(self.num_dense_layers + layer_idx)

if self.mtp_instance is not None:
for mtp_layer_idx in range(self.num_mtp_layers):
self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
self.mtp_instance.model.get_log2phy_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)

self.all_topk_ids = []

@@ -103,13 +130,30 @@ def init_expert_param_per_layer(self):
name].data[local_expert_id]
for name in self.expert_weight_names
])

if self.mtp_instance is not None:
mtp_param_dict = dict(self.mtp_instance.named_parameters())
for mtp_layer_idx in range(self.num_mtp_layers):
self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = list()
for local_expert_id in range(num_local_expert):
for mtp_layer_idx in range(self.num_mtp_layers):
self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
".mtp_block.mlp.experts." +
name].data[local_expert_id]
for name in self.mtp_expert_weight_names
])
Comment on lines +140 to +145
Contributor
high

The construction of parameter names for MTP layers is hardcoded and very specific to the current model structure. This makes the code brittle and difficult to maintain or extend to other MTP models. While the suggestion improves readability and removes a redundant calculation, the core issue of the hardcoded path remains. Consider making this more robust, for example by having the MTP model itself expose a method to get expert parameter names for a given layer.

                    layer_idx = self.num_dense_layers + self.num_moe_layers + mtp_layer_idx
                    self.expert_param_per_layer[layer_idx].append([
                        mtp_param_dict[f"model.layers.{layer_idx}.mtp_block.mlp.experts.{name}"].data[local_expert_id]
                        for name in self.mtp_expert_weight_names
                    ])
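
A minimal sketch of that suggestion, assuming a method such as the following were added to the MTP model class (the name get_expert_param_names and its signature are hypothetical and not part of this PR):

    def get_expert_param_names(self, layer_idx: int, weight_names: list[str]) -> list[str]:
        # Keep the model-specific ".mtp_block.mlp.experts." path inside the MTP
        # model itself instead of rebuilding it in the EPLB adaptor.
        prefix = f"model.layers.{layer_idx}.mtp_block.mlp.experts."
        return [prefix + name for name in weight_names]

The loop in init_expert_param_per_layer would then only iterate over self.mtp_instance.get_expert_param_names(layer_idx, self.mtp_expert_weight_names) and index mtp_param_dict with each returned name, so a change to the MTP module layout would touch only the model class.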


def get_rank_expert_workload(self) -> torch.Tensor:
self.moe_load = self.model.get_all_moe_loads()
if self.mtp_instance is not None:
self.moe_load = torch.cat([self.moe_load, self.mtp_instance.model.get_all_moe_loads().to(device=self.moe_load.device)], dim=0)
return self.moe_load

def get_init_expert_map(self, num_moe_layers):
expert_map = self.model.get_all_expert_map(num_moe_layers)
if self.mtp_instance is not None:
expert_map = torch.cat([expert_map, self.mtp_instance.model.get_all_expert_map().to(device=expert_map.device)], dim=0)
if dist.is_initialized():
world_size = dist.get_world_size()

@@ -261,9 +305,11 @@ def determine_expert_map_all(self):
local_num_experts = self.global_expert_num // self.world_size

expert_map_all = torch.full(
(self.num_moe_layers, self.world_size, self.global_expert_num),
-1,
dtype=torch.int32)
(self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers),
self.world_size,
self.global_expert_num),
-1,
dtype=torch.int32)

for r in range(self.world_size):
if r < self.world_size - 1:
@@ -284,6 +330,6 @@

local_ids = torch.arange(local_count, dtype=torch.int32)
expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
self.num_moe_layers, -1)
self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers), -1)

return expert_map_all
8 changes: 6 additions & 2 deletions vllm_ascend/eplb/eplb_updator.py
@@ -35,9 +35,11 @@ def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
self.eplb_process = eplb_process
self.shared_dict = self.eplb_process.shared_dict

def set_adaptor(self, adaptor):
def set_adaptor(self, adaptor, num_mtp_layers):
self.adaptor = adaptor
self.num_moe_layers = self.adaptor.num_moe_layers
self.num_moe_layers = (
self.adaptor.num_moe_layers if self.adaptor.mtp_instance is None else self.adaptor.num_moe_layers + num_mtp_layers
)
self.global_expert_num = self.adaptor.global_expert_num

def init_eplb(self, expert_map_path, process):
@@ -84,6 +86,8 @@ def update_iteration(self):
self.expert_map_record_path)

self.adaptor.model.clear_all_moe_loads()
if self.adaptor.mtp_instance is not None:
self.adaptor.mtp_instance.model.clear_all_moe_loads()
if not self.gate_eplb:
self.cur_iterations = 0

94 changes: 61 additions & 33 deletions vllm_ascend/eplb/utils.py
@@ -19,45 +19,72 @@

import torch

from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor


def get_expert_map(self, layer_id):
return self.model.layers[layer_id].mlp.experts.get_map()
if not isinstance(self, DeepSeekMultiTokenPredictor):
return self.model.layers[layer_id].mlp.experts.get_map()
else:
return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()


def get_log2phy_map(self, layer_id):
return self.model.layers[layer_id].mlp.experts.get_log2phy_map()


def get_all_expert_map(self, num_moe_layers):
all_loads = []
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
for layer_id in range(num_moe_layers):
load_tensor = self.get_expert_map(
layer_id + num_dense_layers) # (num_experts_per_layer,)
all_loads.append(load_tensor)
if not isinstance(self, DeepSeekMultiTokenPredictor):
return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
else:
return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()


def get_all_expert_map(self, num_moe_layers=None):
if not isinstance(self, DeepSeekMultiTokenPredictor):
all_loads = []
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
for layer_id in range(num_moe_layers):
load_tensor = self.get_expert_map(
layer_id + num_dense_layers) # (num_experts_per_layer,)
all_loads.append(load_tensor)
else:
all_loads = []
for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
load_tensor = self.get_expert_map(layer_id)
all_loads.append(load_tensor)

return torch.stack(all_loads, dim=0)


def get_all_moe_loads(self):
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
all_moe_loads = torch.stack(
[self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
for layer_id in range(self.num_moe_layers)],
dim=0
)
if not isinstance(self, DeepSeekMultiTokenPredictor):
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
all_moe_loads = torch.stack(
[self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
for layer_id in range(self.num_moe_layers)],
dim=0
)
else:
all_moe_loads = torch.stack(
[self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)],
dim=0
)
return all_moe_loads


def clear_all_moe_loads(self):
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
for layer_id in range(self.num_moe_layers):
self.model.layers[layer_id +
num_dense_layers].mlp.experts.clear_moe_load()

if not isinstance(self, DeepSeekMultiTokenPredictor):
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
for layer_id in range(self.num_moe_layers):
self.model.layers[layer_id +
num_dense_layers].mlp.experts.clear_moe_load()
else:
for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()

Comment on lines 23 to +86
Contributor
high

The functions get_expert_map, get_log2phy_map, get_all_expert_map, get_all_moe_loads, and clear_all_moe_loads all contain repetitive if not isinstance(self, DeepSeekMultiTokenPredictor): ... else: ... logic. The else blocks also use hardcoded attribute access paths (e.g., .mtp_block.mlp.experts), which is brittle and not easily extensible to other MTP models. This duplicated logic and hardcoding makes the code difficult to maintain.

Consider refactoring this by introducing a helper method that abstracts away the difference in accessing the expert-related modules between the main model and the MTP model. This would centralize the model-specific logic and make the main functions cleaner and more generic.

For example:

def _get_experts_for_layer(model, layer_id):
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        return model.model.layers[layer_id].mlp.experts
    else:
        # This is still model-specific, but at least it's centralized.
        return model.layers[str(layer_id)].mtp_block.mlp.experts

def get_expert_map(self, layer_id):
    experts = _get_experts_for_layer(self, layer_id)
    return experts.get_map()

# ... and so on for other functions
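
Extending that sketch one step further (still hypothetical helpers, not code in this PR): if the layer-id enumeration were centralized the same way, the remaining registered methods would each collapse to a few lines, for example:

    def _moe_layer_ids(model):
        # Absolute indices of the layers that hold routed experts, for either
        # the main model or the MTP predictor.
        if not isinstance(model, DeepSeekMultiTokenPredictor):
            num_dense = getattr(model, "num_dense_layers", 0)
            return range(num_dense, num_dense + model.num_moe_layers)
        return range(model.mtp_start_layer_idx,
                     model.mtp_start_layer_idx + model.num_mtp_layers)

    def get_all_moe_loads(self):
        return torch.stack(
            [_get_experts_for_layer(self, i).moe_load for i in _moe_layer_ids(self)],
            dim=0)

    def clear_all_moe_loads(self):
        for i in _moe_layer_ids(self):
            _get_experts_for_layer(self, i).clear_moe_load()

All attribute names used here (num_dense_layers, num_moe_layers, mtp_start_layer_idx, num_mtp_layers, moe_load) come from the diff above; only the two helper functions are new.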



def model_register(model, model_config):
model.get_expert_map = types.MethodType(get_expert_map, model)
@@ -66,12 +93,13 @@ def model_register(model, model_config):
model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)

config = model_config.hf_config
if not isinstance(model, DeepSeekMultiTokenPredictor):
config = model_config.hf_config

if config.model_type == "qwen3_moe":
model.num_moe_layers = config.num_hidden_layers
elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
model.num_dense_layers = config.first_k_dense_replace
model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
else:
raise NotImplementedError("EPLB is not supported.")
if config.model_type == "qwen3_moe":
model.num_moe_layers = config.num_hidden_layers
elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
model.num_dense_layers = config.first_k_dense_replace
model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
else:
raise NotImplementedError("EPLB is not supported.")
26 changes: 23 additions & 3 deletions vllm_ascend/worker/model_runner_v1.py
@@ -62,6 +62,7 @@
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
supports_mrope,
supports_transcription)
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -278,6 +279,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
'decode_max_num_seqs', 0)
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
decode_max_num_seqs)
self.mtp_instance = None
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
@@ -3007,6 +3009,8 @@ def dummy_compute_logits(hidden_states):
hidden_states[dummy_indices])
if self.in_profile_run and self.dynamic_eplb:
self.model.clear_all_moe_loads()
if self.mtp_instance is not None:
self.drafter.model.model.clear_all_moe_loads()
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
self.eplb_updator.forward_end()
@@ -3129,9 +3133,14 @@ def _dummy_pooler_run(
def eplb_warmup(self):
if self.dynamic_eplb and not self.is_eplb_warmuped:
self.is_eplb_warmuped = True
self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
num_mtp_layers = self.mtp_instance.model.num_mtp_layers if self.mtp_instance is not None else 0
self.eplb_adaptor = VllmEplbAdaptor(
model=self.model,
mtp_instance=self.mtp_instance,
num_mtp_layers=num_mtp_layers
)
self.eplb_loader.set_adator(self.eplb_adaptor)
self.eplb_updator.set_adaptor(self.eplb_adaptor)
self.eplb_updator.set_adaptor(self.eplb_adaptor, num_mtp_layers)
self.eplb_updator.warm_up_eplb()

def load_model(self) -> None:
@@ -3140,7 +3149,7 @@ def load_model(self) -> None:
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if self.dynamic_eplb:
model_register(self.model, self.model_config)
model_register(self.model, self.model_config)
if is_310p():
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear,
@@ -3154,6 +3163,17 @@
if self.drafter:
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
assert isinstance(self.drafter, MtpProposer), \
f"drafter type wrong: {type(self.drafter)}"
assert isinstance(self.drafter.model, (DeepSeekMTP, ACLGraphWrapper)), \
f"drafter type wrong: {type(self.drafter)}, only suport DeepSeekMTP or ACLGraphWrapper"
if isinstance(self.drafter.model, DeepSeekMTP):
self.mtp_instance = self.drafter.model
elif isinstance(self.drafter.model, ACLGraphWrapper):
self.mtp_instance = self.drafter.model.unwrap()
model_register(self.mtp_instance.model, self.vllm_config)

if self.drafter.name == SpecDcodeType.EAGLE3:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())