-
Notifications
You must be signed in to change notification settings - Fork 976
[megatron] support megatron MTP #6496
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Summary of ChangesHello @Jintao-Huang, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Megatron framework by introducing support for Multi-Token Prediction (MTP) within the GPT model. The changes involve adding new configurable parameters for MTP layers and loss scaling, alongside a substantial refactoring of the GPT model's forward pass. This refactoring separates input preprocessing and output postprocessing into distinct methods, creating a more modular and maintainable architecture that seamlessly incorporates MTP as an additional training objective. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for Multi-Token Prediction (MTP) in Megatron, including new command-line arguments and documentation. The core change is a significant refactoring of the GPTModel.forward method to accommodate the MTP logic. My review focuses on the correctness of this refactoring and the consistency of the documentation. I've identified a critical breaking API change in the forward method signature that needs to be addressed. Additionally, there are some language inconsistencies in the documentation files that should be corrected for clarity.
swift/megatron/model/gpt_model.py
Outdated
| self, | ||
| input_ids: torch.Tensor, | ||
| position_ids: torch.Tensor, | ||
| attention_mask: torch.Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signature of the forward method has been changed, making attention_mask a required positional argument. This is a breaking API change that will cause errors in other parts of the codebase that call this method without providing attention_mask or passing it as a keyword argument. For example, MultimodalGPTModel.forward in swift/megatron/model/mm_gpt_model.py will break. To maintain backward compatibility, attention_mask should remain an optional keyword argument.
| attention_mask: torch.Tensor, | |
| attention_mask: torch.Tensor = None, |
| - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. | ||
| - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The descriptions for the new MTP parameters are in English, but this documentation file is in Chinese. For consistency, please translate these descriptions into Chinese.
| - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. | |
| - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. | |
| - mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 | |
| - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 |
| - mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 | ||
| - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The descriptions for the new MTP parameters are in Chinese, but this documentation file is in English. For consistency, please translate these descriptions into English.
| - mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 | |
| - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 | |
| - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. | |
| - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. |
No description provided.