5 changes: 5 additions & 0 deletions docs/source/Megatron-SWIFT/Command-line-parameters.md
@@ -214,6 +214,11 @@
- qk_head_dim: Dimension of the head in the QK projection. `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`. Default is None and is read automatically from config.json.
- qk_pos_emb_head_dim: Dimension of the position embedding in the QK projection. Default is None and is read automatically from config.json.

**MTP Parameters**
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.

Contributor (severity: medium):

The descriptions for the new MTP parameters are in English, but this documentation file is in Chinese. For consistency, please translate these descriptions into Chinese.

Suggested change
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。



**Tuner Parameters**:
- train_type: Options are 'lora' and 'full'. Default is 'full'.
- 🔥freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, with different effects. In full-parameter training, setting freeze_llm to True freezes the weights of the LLM part; in LoRA training with `target_modules` set to 'all-linear', setting freeze_llm to True skips adding LoRA modules to the LLM part. Default is False.
5 changes: 5 additions & 0 deletions docs/source_en/Megatron-SWIFT/Command-line-parameters.md
@@ -227,6 +227,11 @@ For guidance on selecting parallelization strategies, please refer to the [Train
- qk_head_dim: Dimension of the head in the QK projection. `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`. Default is None and will be automatically read from config.json.
- qk_pos_emb_head_dim: Dimension of the position embedding in the QK projection. Default is None and will be automatically read from config.json.


**MTP Parameters**
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。

Contributor (severity: medium):

The descriptions for the new MTP parameters are in Chinese, but this documentation file is in English. For consistency, please translate these descriptions into English.

Suggested change
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.


**Tuner Parameters**:

- train_type: Options are `'lora'` and `'full'`. Default is `'full'`.
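As an aside on the MTP loss description above, here is a minimal sketch of the averaging-then-scaling step it describes. The helper function and the loss values are hypothetical illustrations, not code from this PR or from Megatron-LM.

```python
import torch


def combine_mtp_loss(lm_loss: torch.Tensor,
                     mtp_losses: list,
                     mtp_loss_scaling_factor: float = 0.1) -> torch.Tensor:
    """Hypothetical helper mirroring the documented behaviour: average the
    per-depth MTP losses, scale them, and add the result to the main
    language-modeling loss as an auxiliary objective."""
    if not mtp_losses:  # mtp_num_layers is None -> no extra objective
        return lm_loss
    mtp_loss = torch.stack(mtp_losses).mean()  # average over all D depths
    return lm_loss + mtp_loss_scaling_factor * mtp_loss


# With mtp_num_layers=2, two additional per-depth losses would be produced per step.
total = combine_mtp_loss(torch.tensor(2.1), [torch.tensor(2.4), torch.tensor(2.6)])
```
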
4 changes: 4 additions & 0 deletions swift/megatron/argument/megatron_args.py
@@ -313,6 +313,10 @@ class MegatronArguments(ExtraMegatronArguments):
qk_head_dim: Optional[int] = None
qk_pos_emb_head_dim: Optional[int] = None

# mtp
mtp_num_layers: Optional[int] = None
mtp_loss_scaling_factor: float = 0.1

# fp8
fp8_format: Literal['e4m3', 'hybrid'] = None
fp8_recipe: Literal['tensorwise', 'delayed', 'mxfp8', 'blockwise'] = 'delayed'
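For context on the two new fields, a small hypothetical sketch of how they would be set; `MTPConfig` below is a stand-in for the relevant slice of `MegatronArguments`, not the real class, with defaults mirroring the diff above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MTPConfig:
    # Stand-in for the two fields added to MegatronArguments above.
    mtp_num_layers: Optional[int] = None   # None means MTP is disabled
    mtp_loss_scaling_factor: float = 0.1   # weight of the auxiliary MTP loss


cfg = MTPConfig(mtp_num_layers=1)  # e.g. predict one additional future token
assert cfg.mtp_loss_scaling_factor == 0.1
```
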
187 changes: 154 additions & 33 deletions swift/megatron/model/gpt_model.py
@@ -4,16 +4,17 @@
from typing import Any, Dict, Literal, Optional, Tuple

import torch
from megatron.core import InferenceParams
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import TELinear
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt import GPTModel as McoreGPTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.training import get_args

from swift.utils import get_logger
@@ -145,30 +146,20 @@ def apply_rotary_pos_emb(*args, **kwargs):
finally:
attention.apply_rotary_pos_emb = origin_apply_rotary_pos_emb

# Code borrowed from NVIDIA/Megatron-LM
def forward(
def _preprocess(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor = None,
decoder_input: torch.Tensor = None,
labels: torch.Tensor = None,
inference_params: InferenceParams = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
**kwargs,
) -> torch.Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).

It either returns the Loss values if labels are given or the final hidden units
):
"""Preprocesses inputs for the transformer decoder.

Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
Applies embeddings to input tokens, or uses `decoder_input` from a previous
pipeline stage. Also sets up rotary positional embeddings.
"""

# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

@@ -185,20 +176,23 @@ def forward(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)

# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type in {'rope', 'mrope'}:
if not self.training and self.config.flash_decode and inference_params:
if not self.training and self.config.flash_decode and inference_context:
assert (inference_context.is_static_batching()
), 'GPTModel currently only supports static inference batching.'
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
inference_context.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
)
else:
rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_params, self.decoder, decoder_input,
self.config, packed_seq_params)
rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder,
decoder_input, self.config, packed_seq_params)
if self.hf_rope_scaling is not None:
attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len)
if attention_scaling is not None:
@@ -214,38 +208,92 @@ def forward(
rotary_seq_len,
packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd',
)

if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None
and inference_params):
and inference_context and inference_context.is_static_batching() and not self.training):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Run decoder.
with self._patch_apply_rotary_pos_emb():
hidden_states = self.decoder(
hidden_states=decoder_input,
# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if (inference_context is not None and not self.training and not has_config_logger_enabled(self.config)):
decoder_input = WrappedTensor(decoder_input)

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset

def _postprocess(
self,
hidden_states,
input_ids,
position_ids,
labels,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
mtp_in_postprocess=None,
loss_mask=None,
decoder_input=None,
attention_mask=None,
inference_params=None,
packed_seq_params=None,
sequence_len_offset=None,
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
):
"""Postprocesses decoder hidden states to generate logits or compute loss.

Applies Multi-Token Prediction if enabled, generates output logits through
the output layer, and computes language model loss when labels are provided.
"""
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()

if mtp_in_postprocess:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
**kwargs,
)

if not self.post_process:
return hidden_states

if (not self.training and inference_context is not None
and inference_context.materialize_only_last_token_logits):
if inference_context.is_static_batching():
hidden_states = hidden_states[-1:, :, :]
else:
# Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden
# state ([B, H]) → unsqueeze back to [1, B, H]
# (so that the output layer, which expects S×B×H, receives only the final token)
hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)

# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
args = get_args()
if args.task_type == 'causal_lm':
logits, _ = self.output_layer(
@@ -271,5 +319,78 @@ def forward(

return loss

# Code borrowed from NVIDIA/Megatron-LM
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,

Contributor (severity: critical):

The signature of the forward method has been changed, making attention_mask a required positional argument. This is a breaking API change that will cause errors in parts of the codebase that call this method without passing attention_mask, either positionally or as a keyword argument. For example, MultimodalGPTModel.forward in swift/megatron/model/mm_gpt_model.py will break. To maintain backward compatibility, attention_mask should remain an optional keyword argument.

Suggested change
attention_mask: torch.Tensor,
attention_mask: torch.Tensor = None,

decoder_input: torch.Tensor = None,
labels: torch.Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).

It either returns the Loss values if labels are given or the final hidden units

Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""

inference_context = deprecate_inference_params(inference_context, inference_params)

decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self._preprocess(
input_ids=input_ids,
position_ids=position_ids,
decoder_input=decoder_input,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
))
# Run decoder.
with self._patch_apply_rotary_pos_emb():
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
**kwargs,
)

return self._postprocess(
hidden_states=hidden_states,
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
mtp_in_postprocess=self.mtp_process,
loss_mask=loss_mask,
decoder_input=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
runtime_gather_output=runtime_gather_output,
extra_block_kwargs=extra_block_kwargs,
inference_context=inference_context,
)

def get_input_tensor(self):
return self.decoder.input_tensor
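To make the compatibility point from the "critical" review comment above concrete, here is a toy sketch (not the real GPTModel) showing why keeping a None default on attention_mask preserves existing call sites that never pass it.

```python
import torch


class ToyModel:
    # With the default in place (as the review suggests), callers that omit
    # attention_mask keep working; a required positional parameter would make
    # the same call raise a TypeError.
    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
                attention_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
        del attention_mask, kwargs  # unused in this toy example
        return input_ids + position_ids


model = ToyModel()
ids = torch.arange(4)
out = model.forward(ids, ids)  # legal only because attention_mask has a default
```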