[Feat] Enable EPLB to support MTP layers #4598
Changes from all commits
```diff
@@ -19,45 +19,72 @@
 import torch
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor
 
 
 def get_expert_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()
 
 
 def get_log2phy_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()
 
 
-def get_all_expert_map(self, num_moe_layers):
-    all_loads = []
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(num_moe_layers):
-        load_tensor = self.get_expert_map(
-            layer_id + num_dense_layers)  # (num_experts_per_layer,)
-        all_loads.append(load_tensor)
+def get_all_expert_map(self, num_moe_layers=None):
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        all_loads = []
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_expert_map(
+                layer_id + num_dense_layers)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+    else:
+        all_loads = []
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            load_tensor = self.get_expert_map(layer_id)
+            all_loads.append(load_tensor)
 
     return torch.stack(all_loads, dim=0)
 
 
 def get_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    all_moe_loads = torch.stack(
-        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
-            for layer_id in range(self.num_moe_layers)],
-        dim=0
-    )
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
+                for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+    else:
+        all_moe_loads = torch.stack(
+            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
+                for idx in range(self.mtp_start_layer_idx,
+                                 self.mtp_start_layer_idx + self.num_mtp_layers)],
+            dim=0
+        )
     return all_moe_loads
 
 
 def clear_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(self.num_moe_layers):
-        self.model.layers[layer_id +
-                          num_dense_layers].mlp.experts.clear_moe_load()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id +
+                              num_dense_layers].mlp.experts.clear_moe_load()
+    else:
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()
```
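For orientation, here is a hedged sketch of how these per-model accessors might be consumed on the EPLB side. `collect_combined_loads` and the two model handles are illustrative names, not something this diff adds:

```python
import torch


def collect_combined_loads(base_model, mtp_model):
    # Illustrative only: both handles are assumed to have gone through
    # model_register(), so get_all_moe_loads() is bound on each of them.
    base_loads = base_model.get_all_moe_loads()  # one row per decoder MoE layer
    mtp_loads = mtp_model.get_all_moe_loads()    # one row per MTP layer
    # Concatenate so the EPLB rebalancer sees MTP layers as additional MoE layers.
    return torch.cat([base_loads, mtp_loads], dim=0)
```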
Comment on lines 23 to +86

Contributor

The functions `get_expert_map`, `get_log2phy_map`, `get_all_expert_map`, `get_all_moe_loads`, and `clear_all_moe_loads` all repeat the same `isinstance(self, DeepSeekMultiTokenPredictor)` branch to reach the expert modules. Consider refactoring this by introducing a helper method that abstracts away the difference in accessing the expert-related modules between the main model and the MTP model. This would centralize the model-specific logic and make the main functions cleaner and more generic. For example:

```python
def _get_experts_for_layer(model, layer_id):
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        return model.model.layers[layer_id].mlp.experts
    else:
        # This is still model-specific, but at least it's centralized.
        return model.layers[str(layer_id)].mtp_block.mlp.experts


def get_expert_map(self, layer_id):
    experts = _get_experts_for_layer(self, layer_id)
    return experts.get_map()

# ... and so on for the other functions
```
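Building on that suggestion, the load-tracking accessors could reuse the same helper together with a small layer-id helper. This is only a sketch, assuming the reviewer's `_get_experts_for_layer` above and the attribute names already used in this diff (`num_dense_layers`, `num_moe_layers`, `mtp_start_layer_idx`, `num_mtp_layers`):

```python
def _moe_layer_ids(model):
    # Absolute ids of the layers that carry MoE experts, for either model type.
    if not isinstance(model, DeepSeekMultiTokenPredictor):
        num_dense_layers = getattr(model, "num_dense_layers", 0)
        return range(num_dense_layers, num_dense_layers + model.num_moe_layers)
    return range(model.mtp_start_layer_idx,
                 model.mtp_start_layer_idx + model.num_mtp_layers)


def get_all_moe_loads(self):
    # No isinstance branch left in the body itself.
    return torch.stack(
        [_get_experts_for_layer(self, layer_id).moe_load
         for layer_id in _moe_layer_ids(self)],
        dim=0)


def clear_all_moe_loads(self):
    for layer_id in _moe_layer_ids(self):
        _get_experts_for_layer(self, layer_id).clear_moe_load()
```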
```diff
 def model_register(model, model_config):
     model.get_expert_map = types.MethodType(get_expert_map, model)
@@ -66,12 +93,13 @@ def model_register(model, model_config):
     model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
     model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)
 
-    config = model_config.hf_config
+    if not isinstance(model, DeepSeekMultiTokenPredictor):
+        config = model_config.hf_config
 
-    if config.model_type == "qwen3_moe":
-        model.num_moe_layers = config.num_hidden_layers
-    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
-        model.num_dense_layers = config.first_k_dense_replace
-        model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
-    else:
-        raise NotImplementedError("EPLB is not supported.")
+        if config.model_type == "qwen3_moe":
+            model.num_moe_layers = config.num_hidden_layers
+        elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
+            model.num_dense_layers = config.first_k_dense_replace
+            model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
+        else:
+            raise NotImplementedError("EPLB is not supported.")
```
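For completeness, a hedged sketch of how this registration could be driven for both modules. `register_eplb_hooks` and the `mtp_model` handle are assumptions for illustration, not something this PR adds:

```python
def register_eplb_hooks(model, mtp_model, model_config):
    # Main model: model_register() also derives num_dense_layers / num_moe_layers
    # from model_config.hf_config.
    model_register(model, model_config)
    # MTP predictor: the hf_config branch is skipped, since it already carries
    # mtp_start_layer_idx and num_mtp_layers.
    if mtp_model is not None:
        model_register(mtp_model, model_config)
```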
The construction of parameter names for MTP layers is hardcoded and very specific to the current model structure. This makes the code brittle and difficult to maintain or extend to other MTP models. While the suggestion improves readability and removes a redundant calculation, the core issue of the hardcoded path remains. Consider making this more robust, for example by having the MTP model itself expose a method that returns the expert parameter names for a given layer.
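One possible shape for that suggestion, sketched with assumed names: in practice such a method would live on `DeepSeekMultiTokenPredictor` (or be bound like the other accessors), and the weight suffixes below are placeholders rather than the project's actual parameter names:

```python
def get_expert_param_names(self, layer_id, suffixes=("w13_weight", "w2_weight")):
    # Hypothetical accessor: the MTP model owns its own naming scheme, so EPLB
    # code no longer hardcodes the "layers.<id>.mtp_block.mlp.experts" path.
    prefix = f"layers.{layer_id}.mtp_block.mlp.experts"
    return [f"{prefix}.{suffix}" for suffix in suffixes]
```

EPLB's weight-handling code could then ask the model for these names instead of reconstructing the path itself.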