[megatron] support megatron MTP #6496
@@ -227,6 +227,11 @@ For guidance on selecting parallelization strategies, please refer to the [Train | |||||||||
| - qk_head_dim: Dimension of the head in the QK projection. `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`. Default is None and will be automatically read from config.json. | ||||||||||
| - qk_pos_emb_head_dim: Dimension of the position embedding in the QK projection. Default is None and will be automatically read from config.json. | ||||||||||

**MTP Parameters**
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to predict D additional tokens in sequence. Default is None.
- mtp_loss_scaling_factor: Scaling factor of the Multi-Token Prediction (MTP) loss. We compute the average of the MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
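For illustration only, here is a minimal sketch of how these two parameters combine into the auxiliary objective described above. The function name and call pattern are hypothetical, not part of the ms-swift or Megatron API:

```python
from typing import List

import torch


def overall_mtp_loss(per_depth_losses: List[torch.Tensor], scaling_factor: float = 0.1) -> torch.Tensor:
    """Average the MTP losses over all D depths, then scale by mtp_loss_scaling_factor.

    The result is added to the standard next-token loss as an extra training objective.
    """
    return scaling_factor * torch.stack(per_depth_losses).mean()


# Example with D = 2 (i.e. mtp_num_layers=2) and dummy per-depth loss values.
aux_loss = overall_mtp_loss([torch.tensor(2.1), torch.tensor(2.4)], scaling_factor=0.1)
```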
@@ -4,16 +4,17 @@ | |||||
| from typing import Any, Dict, Literal, Optional, Tuple | ||||||
|
|
||||||
| import torch | ||||||
| from megatron.core import InferenceParams | ||||||
| from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk | ||||||
| from megatron.core.dist_checkpointing.mapping import ShardedStateDict | ||||||
| from megatron.core.extensions.transformer_engine import TELinear | ||||||
| from megatron.core.inference.contexts import BaseInferenceContext | ||||||
| from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding | ||||||
| from megatron.core.models.gpt import GPTModel as McoreGPTModel | ||||||
| from megatron.core.packed_seq_params import PackedSeqParams | ||||||
| from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region | ||||||
| from megatron.core.transformer.spec_utils import ModuleSpec | ||||||
| from megatron.core.transformer.transformer_config import TransformerConfig | ||||||
| from megatron.core.utils import WrappedTensor, deprecate_inference_params | ||||||
| from megatron.training import get_args | ||||||
|
|
||||||
| from swift.utils import get_logger | ||||||
|
|
@@ -145,30 +146,20 @@ def apply_rotary_pos_emb(*args, **kwargs): | |||||
| finally: | ||||||
| attention.apply_rotary_pos_emb = origin_apply_rotary_pos_emb | ||||||
|
|
||||||
| # Code borrowed from NVIDIA/Megatron-LM | ||||||
| def forward( | ||||||
| def _preprocess( | ||||||
| self, | ||||||
| input_ids: torch.Tensor, | ||||||
| position_ids: torch.Tensor, | ||||||
| attention_mask: torch.Tensor = None, | ||||||
| decoder_input: torch.Tensor = None, | ||||||
| labels: torch.Tensor = None, | ||||||
| inference_params: InferenceParams = None, | ||||||
| inference_context: BaseInferenceContext = None, | ||||||
| packed_seq_params: PackedSeqParams = None, | ||||||
| extra_block_kwargs: dict = None, | ||||||
| runtime_gather_output: Optional[bool] = None, | ||||||
| **kwargs, | ||||||
| ) -> torch.Tensor: | ||||||
| """Forward function of the GPT Model This function passes the input tensors | ||||||
| through the embedding layer, and then the decoeder and finally into the post | ||||||
| processing layer (optional). | ||||||
|
|
||||||
| It either returns the Loss values if labels are given or the final hidden units | ||||||
| ): | ||||||
| """Preprocesses inputs for the transformer decoder. | ||||||
|
|
||||||
| Args: | ||||||
| runtime_gather_output (bool): Gather output at runtime. Default None means | ||||||
| `parallel_output` arg in the constructor will be used. | ||||||
| Applies embeddings to input tokens, or uses `decoder_input` from a previous | ||||||
| pipeline stage. Also sets up rotary positional embeddings. | ||||||
| """ | ||||||
|
|
||||||
| # If decoder_input is provided (not None), then input_ids and position_ids are ignored. | ||||||
| # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. | ||||||
|
|
||||||
|
|
@@ -185,20 +176,23 @@ def forward( | |||||
| if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: | ||||||
| # fix LoRA incompatibility with gradient checkpointing | ||||||
| decoder_input = decoder_input.requires_grad_(True) | ||||||
|
|
||||||
| # Rotary positional embeddings (embedding is None for PP intermediate devices) | ||||||
| rotary_pos_emb = None | ||||||
| rotary_pos_cos = None | ||||||
| rotary_pos_sin = None | ||||||
| if self.position_embedding_type in {'rope', 'mrope'}: | ||||||
| if not self.training and self.config.flash_decode and inference_params: | ||||||
| if not self.training and self.config.flash_decode and inference_context: | ||||||
| assert (inference_context.is_static_batching() | ||||||
| ), 'GPTModel currently only supports static inference batching.' | ||||||
| # Flash decoding uses precomputed cos and sin for RoPE | ||||||
| rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( | ||||||
| inference_params.max_sequence_length, | ||||||
| self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), | ||||||
| inference_context.max_sequence_length, | ||||||
| self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), | ||||||
| ) | ||||||
| else: | ||||||
| rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_params, self.decoder, decoder_input, | ||||||
| self.config, packed_seq_params) | ||||||
| rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, | ||||||
| decoder_input, self.config, packed_seq_params) | ||||||
| if self.hf_rope_scaling is not None: | ||||||
| attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) | ||||||
| if attention_scaling is not None: | ||||||
|
|
@@ -214,38 +208,92 @@ def forward( | |||||
| rotary_seq_len, | ||||||
| packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', | ||||||
| ) | ||||||
|
|
||||||
| if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None | ||||||
| and inference_params): | ||||||
| and inference_context and inference_context.is_static_batching() and not self.training): | ||||||
| current_batch_size = input_ids.shape[0] | ||||||
| sequence_len_offset = torch.tensor( | ||||||
| [inference_params.sequence_len_offset] * inference_params.current_batch_size, | ||||||
| [inference_context.sequence_len_offset] * current_batch_size, | ||||||
| dtype=torch.int32, | ||||||
| device=rotary_pos_cos.device, # Co-locate this with the rotary tensors | ||||||
| ) | ||||||
| else: | ||||||
| sequence_len_offset = None | ||||||
|
|
||||||
| # Run decoder. | ||||||
| with self._patch_apply_rotary_pos_emb(): | ||||||
| hidden_states = self.decoder( | ||||||
| hidden_states=decoder_input, | ||||||
| # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the | ||||||
| # reference held by this caller function, enabling early garbage collection for | ||||||
| # inference. Skip wrapping if decoder_input is logged after decoder completion. | ||||||
| if (inference_context is not None and not self.training and not has_config_logger_enabled(self.config)): | ||||||
| decoder_input = WrappedTensor(decoder_input) | ||||||
|
|
||||||
| return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset | ||||||
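As a side note on the wrapping trick above, the snippet below is a self-contained sketch of the same idea. The `Wrapped` class is hypothetical and only stands in for Megatron's `WrappedTensor`, whose exact API is not shown in this diff:

```python
import torch


class Wrapped:
    """Holds a tensor so a callee can take ownership and drop this reference early."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    def take(self) -> torch.Tensor:
        # Hand the tensor to the caller and clear the slot, so the only remaining
        # reference lives in the callee; the activation becomes collectible as soon
        # as the callee is done with it, rather than at the end of the outer frame.
        tensor, self.tensor = self.tensor, None
        return tensor
```

In the code above, wrapping is applied only for inference and only when config logging does not need the raw tensor afterwards.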
def _postprocess(
self,
hidden_states,
input_ids,
position_ids,
labels,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
mtp_in_postprocess=None,
loss_mask=None,
decoder_input=None,
attention_mask=None,
inference_params=None,
packed_seq_params=None,
sequence_len_offset=None,
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
):
"""Postprocesses decoder hidden states to generate logits or compute loss.

Applies Multi-Token Prediction if enabled, generates output logits through
the output layer, and computes language model loss when labels are provided.
"""
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()

if mtp_in_postprocess:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
**kwargs,
)

if not self.post_process:
return hidden_states

if (not self.training and inference_context is not None
and inference_context.materialize_only_last_token_logits):
if inference_context.is_static_batching():
hidden_states = hidden_states[-1:, :, :]
else:
# Reshape [B, 1, H] to [1, B, H] → extract each sample's true last-token hidden
# state ([B, H]) → unsqueeze back to [1, B, H]
# (so that the output layer, which expects S×B×H, receives only the final token)
hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)

# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
args = get_args()
if args.task_type == 'causal_lm':
logits, _ = self.output_layer(
@@ -271,5 +319,78 @@ def forward( | |||||
|
|
||||||
| return loss | ||||||
|
|
||||||
# Code borrowed from NVIDIA/Megatron-LM
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
Suggested change:
attention_mask: torch.Tensor,
attention_mask: torch.Tensor = None,
The descriptions for the new MTP parameters are in English, but this documentation file is in Chinese. For consistency, please translate these descriptions into Chinese.
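For orientation, the refactored `forward` (whose body is truncated in this view) delegates to `_preprocess` and `_postprocess`. The sketch below is an illustrative reconstruction of that wiring under assumptions; argument names follow the signatures shown above, but this is not the PR's actual implementation:

```python
def forward_sketch(self, input_ids, position_ids, attention_mask=None, decoder_input=None,
                   labels=None, loss_mask=None, inference_context=None,
                   packed_seq_params=None, extra_block_kwargs=None,
                   runtime_gather_output=None, **kwargs):
    # 1) Embed tokens (or reuse the previous pipeline stage's output) and set up RoPE.
    decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = self._preprocess(
        input_ids, position_ids, decoder_input=decoder_input,
        inference_context=inference_context, packed_seq_params=packed_seq_params)

    # 2) Run the transformer decoder with the rotary-embedding patch from this file.
    with self._patch_apply_rotary_pos_emb():
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

    # 3) Apply the optional MTP block, the output layer, and the loss.
    return self._postprocess(
        hidden_states, input_ids, position_ids, labels,
        rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
        mtp_in_postprocess=getattr(self.config, 'mtp_num_layers', None),  # assumed gating
        loss_mask=loss_mask, attention_mask=attention_mask,
        packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset,
        runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs,
        inference_context=inference_context,
    )
```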