Conversation

@dsxsteven (Contributor) commented Dec 1, 2025

What this PR does / why we need it?

This PR enhances EPLB to support one or more MTP layers. Previously, EPLB balanced experts only for the main model; it now also handles the MTP layers, covering both num_speculative_tokens = 1 and num_speculative_tokens > 1.
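
For context, a minimal sketch of how one or more MTP layers come into play on the vLLM side (the model name and settings below are illustrative assumptions, not part of this PR):

from vllm import LLM

# Hedged sketch: each speculative token corresponds to an MTP layer with
# its own MoE experts, which EPLB can now rebalance alongside the main model.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",    # assumed model, for illustration only
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",        # MTP-based speculative decoding
        "num_speculative_tokens": 2,     # > 1 exercises multiple MTP layers
    },
)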

Does this PR introduce any user-facing change?

No, this PR does not introduce any user-facing changes.

How was this patch tested?

Co-authored-by: Skywalker-EP 173723846@qq.com, pop1120138272@icloud.com, dsxsteven@sina.com

@gemini-code-assist bot left a comment

Code Review

This pull request successfully extends EPLB to support MTP layers, which is a great enhancement. The changes are comprehensive, touching the adaptor, updator, utilities, and model runner to integrate MTP layer handling for load balancing.

My review focuses on improving maintainability. I've identified two areas where the code could be made more robust and less model-specific:

  1. In vllm_adaptor.py, parameter names for MTP layers are constructed using hardcoded strings, which is brittle.
  2. In utils.py, there's significant code duplication and hardcoded attribute paths for handling DeepSeekMultiTokenPredictor instances across several functions.

Addressing these points will make the implementation more resilient to future model structure changes and easier to extend for other MTP-like models.

Comment on lines +140 to +145 in vllm_adaptor.py
                    self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
                        mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
                                       ".mtp_block.mlp.experts." + name].data[local_expert_id]
                        for name in self.mtp_expert_weight_names
                    ])
@gemini-code-assist bot (severity: high)

The construction of parameter names for MTP layers is hardcoded and very specific to the current model structure. This makes the code brittle and difficult to maintain or extend to other MTP models. While the suggestion improves readability and removes a redundant calculation, the core issue of the hardcoded path remains. Consider making this more robust, for example by having the MTP model itself expose a method to get expert parameter names for a given layer.

                    layer_idx = self.num_dense_layers + self.num_moe_layers + mtp_layer_idx
                    self.expert_param_per_layer[layer_idx].append([
                        mtp_param_dict[f"model.layers.{layer_idx}.mtp_block.mlp.experts.{name}"].data[local_expert_id]
                        for name in self.mtp_expert_weight_names
                    ])
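
A minimal sketch of that "expose a method" direction, assuming a hypothetical get_expert_param_names method on the MTP model (name and signature are illustrative, not part of this PR):

def get_expert_param_names(self, layer_idx, weight_names):
    # Hypothetical method on DeepSeekMultiTokenPredictor: the model owns its
    # parameter-name layout, so the adaptor no longer hardcodes the
    # "model.layers.{i}.mtp_block.mlp.experts." path.
    prefix = f"model.layers.{layer_idx}.mtp_block.mlp.experts."
    return [prefix + name for name in weight_names]

The adaptor loop would then reduce to:

layer_idx = self.num_dense_layers + self.num_moe_layers + mtp_layer_idx
names = self.model.get_expert_param_names(layer_idx, self.mtp_expert_weight_names)
self.expert_param_per_layer[layer_idx].append(
    [mtp_param_dict[n].data[local_expert_id] for n in names])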

Comment on lines 25 to +86 in utils.py
 def get_expert_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()


 def get_log2phy_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
-
-
-def get_all_expert_map(self, num_moe_layers):
-    all_loads = []
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(num_moe_layers):
-        load_tensor = self.get_expert_map(
-            layer_id + num_dense_layers)  # (num_experts_per_layer,)
-        all_loads.append(load_tensor)
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()
+
+
+def get_all_expert_map(self, num_moe_layers=None):
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        all_loads = []
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_expert_map(
+                layer_id + num_dense_layers)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+    else:
+        all_loads = []
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            load_tensor = self.get_expert_map(layer_id)
+            all_loads.append(load_tensor)

     return torch.stack(all_loads, dim=0)


 def get_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    all_moe_loads = torch.stack(
-        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
-         for layer_id in range(self.num_moe_layers)],
-        dim=0
-    )
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
+             for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+    else:
+        all_moe_loads = torch.stack(
+            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
+             for idx in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers)],
+            dim=0
+        )
     return all_moe_loads


 def clear_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(self.num_moe_layers):
-        self.model.layers[layer_id +
-                          num_dense_layers].mlp.experts.clear_moe_load()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id +
+                              num_dense_layers].mlp.experts.clear_moe_load()
+    else:
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()

@gemini-code-assist bot (severity: high)

The functions get_expert_map, get_log2phy_map, get_all_expert_map, get_all_moe_loads, and clear_all_moe_loads all contain repetitive if not isinstance(self, DeepSeekMultiTokenPredictor): ... else: ... logic. The else blocks also use hardcoded attribute access paths (e.g., .mtp_block.mlp.experts), which is brittle and not easily extensible to other MTP models. This duplicated logic and hardcoding makes the code difficult to maintain.

Consider refactoring this by introducing a helper method that abstracts away the difference in accessing the expert-related modules between the main model and the MTP model. This would centralize the model-specific logic and make the main functions cleaner and more generic.

For example:

def _get_experts_for_layer(model, layer_id):
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        return model.model.layers[layer_id].mlp.experts
    else:
        # This is still model-specific, but at least it's centralized.
        return model.layers[str(layer_id)].mtp_block.mlp.experts

def get_expert_map(self, layer_id):
    experts = _get_experts_for_layer(self, layer_id)
    return experts.get_map()

# ... and so on for other functions
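
Building on _get_experts_for_layer, a hypothetical companion helper could also unify the layer-id iteration, so each function collapses to a single loop (names are illustrative):

def _moe_layer_ids(model):
    # Hypothetical helper: absolute layer ids whose experts participate in
    # EPLB, for either the main model or the MTP model.
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        num_dense_layers = getattr(model, "num_dense_layers", 0)
        return range(num_dense_layers, num_dense_layers + model.num_moe_layers)
    return range(model.mtp_start_layer_idx,
                 model.mtp_start_layer_idx + model.num_mtp_layers)

def clear_all_moe_loads(self):
    for layer_id in _moe_layer_ids(self):
        _get_experts_for_layer(self, layer_id).clear_moe_load()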

@dsxsteven changed the title from "[feature] Enable EPLB to support MTP layers" to "[Feat] Enable EPLB to support MTP layers" on Dec 1, 2025
@github-actions bot commented Dec 1, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@github-actions bot commented Dec 1, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.
