[Feat] Enable EPLB to support MTP layers #4598
Conversation
Code Review
This pull request successfully extends EPLB to support MTP layers, which is a great enhancement. The changes are comprehensive, touching the adaptor, updator, utilities, and model runner to integrate MTP layer handling for load balancing.
My review focuses on improving maintainability. I've identified two areas where the code could be made more robust and less model-specific:
- In `vllm_adaptor.py`, parameter names for MTP layers are constructed using hardcoded strings, which is brittle.
- In `utils.py`, there is significant code duplication and hardcoded attribute paths for handling `DeepSeekMultiTokenPredictor` instances across several functions.
Addressing these points will make the implementation more resilient to future model structure changes and easier to extend for other MTP-like models.
The hunk under review in `vllm_adaptor.py`:

```python
self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
    mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
                   ".mtp_block.mlp.experts." +
                   name].data[local_expert_id]
    for name in self.mtp_expert_weight_names
])
```
The construction of parameter names for MTP layers is hardcoded and very specific to the current model structure. This makes the code brittle and difficult to maintain or extend to other MTP models. While the suggestion improves readability and removes a redundant calculation, the core issue of the hardcoded path remains. Consider making this more robust, for example by having the MTP model itself expose a method to get expert parameter names for a given layer.
```python
layer_idx = self.num_dense_layers + self.num_moe_layers + mtp_layer_idx
self.expert_param_per_layer[layer_idx].append([
    mtp_param_dict[f"model.layers.{layer_idx}.mtp_block.mlp.experts.{name}"].data[local_expert_id]
    for name in self.mtp_expert_weight_names
])
```
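A minimal sketch of the more robust direction the review suggests, assuming the MTP model itself keeps the path knowledge; `get_expert_param_names` and its signature are hypothetical, not part of this PR:

```python
from typing import List

# Hypothetical method on the MTP model (e.g. DeepSeekMultiTokenPredictor):
# the adaptor asks the model for fully qualified expert parameter names
# instead of concatenating path strings itself.
def get_expert_param_names(self, layer_idx: int,
                           expert_weight_names: List[str]) -> List[str]:
    # The ".mtp_block.mlp.experts." segment is model-specific; keeping it
    # inside the model class leaves the adaptor generic.
    prefix = f"model.layers.{layer_idx}.mtp_block.mlp.experts."
    return [prefix + name for name in expert_weight_names]
```

With such a method, the adaptor loop above would reduce to a lookup over `mtp_model.get_expert_param_names(layer_idx, self.mtp_expert_weight_names)`, with no path strings left in the adaptor.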
The updated functions under review in `utils.py`:

```python
def get_expert_map(self, layer_id):
    if not isinstance(self, DeepSeekMultiTokenPredictor):
        return self.model.layers[layer_id].mlp.experts.get_map()
    else:
        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()


def get_log2phy_map(self, layer_id):
    if not isinstance(self, DeepSeekMultiTokenPredictor):
        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
    else:
        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()


def get_all_expert_map(self, num_moe_layers=None):
    if not isinstance(self, DeepSeekMultiTokenPredictor):
        all_loads = []
        num_dense_layers = self.num_dense_layers if hasattr(
            self, "num_dense_layers") else 0
        for layer_id in range(num_moe_layers):
            load_tensor = self.get_expert_map(
                layer_id + num_dense_layers)  # (num_experts_per_layer,)
            all_loads.append(load_tensor)
    else:
        all_loads = []
        for layer_id in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers):
            load_tensor = self.get_expert_map(layer_id)
            all_loads.append(load_tensor)

    return torch.stack(all_loads, dim=0)


def get_all_moe_loads(self):
    if not isinstance(self, DeepSeekMultiTokenPredictor):
        num_dense_layers = self.num_dense_layers if hasattr(
            self, "num_dense_layers") else 0
        all_moe_loads = torch.stack(
            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load
             for layer_id in range(self.num_moe_layers)],
            dim=0
        )
    else:
        all_moe_loads = torch.stack(
            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)],
            dim=0
        )
    return all_moe_loads


def clear_all_moe_loads(self):
    if not isinstance(self, DeepSeekMultiTokenPredictor):
        num_dense_layers = self.num_dense_layers if hasattr(
            self, "num_dense_layers") else 0
        for layer_id in range(self.num_moe_layers):
            self.model.layers[layer_id +
                              num_dense_layers].mlp.experts.clear_moe_load()
    else:
        for layer_id in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers):
            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()
```
The functions `get_expert_map`, `get_log2phy_map`, `get_all_expert_map`, `get_all_moe_loads`, and `clear_all_moe_loads` all contain repetitive `if not isinstance(self, DeepSeekMultiTokenPredictor): ... else: ...` logic. The `else` blocks also use hardcoded attribute access paths (e.g., `.mtp_block.mlp.experts`), which is brittle and not easily extensible to other MTP models. This duplicated logic and hardcoding makes the code difficult to maintain.
Consider refactoring this by introducing a helper method that abstracts away the difference in accessing the expert-related modules between the main model and the MTP model. This would centralize the model-specific logic and make the main functions cleaner and more generic.
For example:
```python
def _get_experts_for_layer(model, layer_id):
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        return model.model.layers[layer_id].mlp.experts
    else:
        # This is still model-specific, but at least it's centralized.
        return model.layers[str(layer_id)].mtp_block.mlp.experts


def get_expert_map(self, layer_id):
    experts = _get_experts_for_layer(self, layer_id)
    return experts.get_map()

# ... and so on for other functions
```
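Extending that idea, the per-function loops could share a single layer-id helper as well. A sketch under the same assumptions, reusing `_get_experts_for_layer` from the example above (`_moe_layer_ids` is a hypothetical name), shown for `get_all_moe_loads`:

```python
import torch

def _moe_layer_ids(model, num_moe_layers=None):
    # Centralizes the layer-id iteration currently duplicated across
    # get_all_expert_map, get_all_moe_loads, and clear_all_moe_loads.
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        num_dense_layers = getattr(model, "num_dense_layers", 0)
        if num_moe_layers is None:
            num_moe_layers = model.num_moe_layers
        return range(num_dense_layers, num_dense_layers + num_moe_layers)
    return range(model.mtp_start_layer_idx,
                 model.mtp_start_layer_idx + model.num_mtp_layers)


def get_all_moe_loads(self):
    # Both isinstance branches collapse into one expression via the helpers.
    return torch.stack(
        [_get_experts_for_layer(self, layer_id).moe_load
         for layer_id in _moe_layer_ids(self)],
        dim=0)
```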
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
- If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
This pull request has conflicts; please resolve them before we can evaluate the pull request.
What this PR does / why we need it?
This PR enhances EPLB to support one or multiple MTP layers. Previously, EPLB only supported the main model. Now, it can handle `num_speculative_tokens = 1` or `num_speculative_tokens > 1`.
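For context on the indexing: the adaptor change above places the MTP layers after the main model's dense and MoE layers in the global expert tables. A small illustration; the layer counts below are example values only, not taken from this PR:

```python
# Example values only; real counts come from the model config.
num_dense_layers = 3   # main-model dense layers
num_moe_layers = 58    # main-model MoE layers
num_mtp_layers = 2     # e.g. num_speculative_tokens = 2

for mtp_layer_idx in range(num_mtp_layers):
    layer_idx = num_dense_layers + num_moe_layers + mtp_layer_idx
    print(f"MTP layer {mtp_layer_idx} -> global layer index {layer_idx}")
# MTP layer 0 -> global layer index 61
# MTP layer 1 -> global layer index 62
```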
Does this PR introduce any user-facing change?
No, this PR does not introduce any user-facing changes.
How was this patch tested?
Co-authored-by: Skywalker-EP 173723846@qq.com, pop1120138272@icloud.com, dsxsteven@sina.com