From 506ae6a713dc955eabd7eab63f8115ec1c0418ef Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 8 Nov 2025 15:51:49 +0800 Subject: [PATCH 1/8] support megatron MTP --- .../Megatron-SWIFT/Command-line-parameters.md | 5 + .../Megatron-SWIFT/Command-line-parameters.md | 5 + swift/megatron/argument/megatron_args.py | 4 + swift/megatron/model/gpt_model.py | 187 ++++++++++++++---- 4 files changed, 168 insertions(+), 33 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 9a1c90658c..fea09478e7 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -214,6 +214,11 @@ - qk_head_dim: QK 投影中 head 的维度。 `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`。默认为None,自动从config.json读取。 - qk_pos_emb_head_dim: QK 投影中位置嵌入的维度。默认为None,自动从config.json读取。 +**MTP参数** +- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. +- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. + + **Tuner参数**: - train_type: 可选为'lora'和'full'。默认为'full'。 - 🔥freeze_llm: 该参数只对多模态模型生效,可用于全参数训练和LoRA训练,但会产生不同的效果。若是全参数训练,将freeze_llm设置为True会将LLM部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_llm设置为True将会取消在LLM部分添加LoRA模块。该参数默认为False。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 7062d13d23..ba7067adf8 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -227,6 +227,11 @@ For guidance on selecting parallelization strategies, please refer to the [Train - qk_head_dim: Dimension of the head in the QK projection. `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`. Default is None and will be automatically read from config.json. - qk_pos_emb_head_dim: Dimension of the position embedding in the QK projection. Default is None and will be automatically read from config.json. + +**MTP Parameters** +- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 +- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 + **Tuner Parameters**: - train_type: Options are `'lora'` and `'full'`. Default is `'full'`. diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index d2b3c96217..f6611b02b9 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -313,6 +313,10 @@ class MegatronArguments(ExtraMegatronArguments): qk_head_dim: Optional[int] = None qk_pos_emb_head_dim: Optional[int] = None + # mtp + mtp_num_layers: Optional[int] = None + mtp_loss_scaling_factor: float = 0.1 + # fp8 fp8_format: Literal['e4m3', 'hybrid'] = None fp8_recipe: Literal['tensorwise', 'delayed', 'mxfp8', 'blockwise'] = 'delayed' diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index b3aa2b4f8d..ea07cdbba4 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -4,16 +4,17 @@ from typing import Any, Dict, Literal, Optional, Tuple import torch -from megatron.core import InferenceParams from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import TELinear +from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel as McoreGPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import WrappedTensor, deprecate_inference_params from megatron.training import get_args from swift.utils import get_logger @@ -145,30 +146,20 @@ def apply_rotary_pos_emb(*args, **kwargs): finally: attention.apply_rotary_pos_emb = origin_apply_rotary_pos_emb - # Code borrowed from NVIDIA/Megatron-LM - def forward( + def _preprocess( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - attention_mask: torch.Tensor = None, decoder_input: torch.Tensor = None, - labels: torch.Tensor = None, - inference_params: InferenceParams = None, + inference_context: BaseInferenceContext = None, packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - runtime_gather_output: Optional[bool] = None, - **kwargs, - ) -> torch.Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units + ): + """Preprocesses inputs for the transformer decoder. - Args: - runtime_gather_output (bool): Gather output at runtime. Default None means - `parallel_output` arg in the constructor will be used. + Applies embeddings to input tokens, or uses `decoder_input` from a previous + pipeline stage. Also sets up rotary positional embeddings. """ + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. @@ -185,20 +176,23 @@ def forward( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) + # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None rotary_pos_sin = None if self.position_embedding_type in {'rope', 'mrope'}: - if not self.training and self.config.flash_decode and inference_params: + if not self.training and self.config.flash_decode and inference_context: + assert (inference_context.is_static_batching() + ), 'GPTModel currently only supports static inference batching.' # Flash decoding uses precomputed cos and sin for RoPE rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( - inference_params.max_sequence_length, - self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), + inference_context.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), ) else: - rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_params, self.decoder, decoder_input, - self.config, packed_seq_params) + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, + decoder_input, self.config, packed_seq_params) if self.hf_rope_scaling is not None: attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) if attention_scaling is not None: @@ -214,20 +208,63 @@ def forward( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) + if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None - and inference_params): + and inference_context and inference_context.is_static_batching() and not self.training): + current_batch_size = input_ids.shape[0] sequence_len_offset = torch.tensor( - [inference_params.sequence_len_offset] * inference_params.current_batch_size, + [inference_context.sequence_len_offset] * current_batch_size, dtype=torch.int32, device=rotary_pos_cos.device, # Co-locate this with the rotary tensors ) else: sequence_len_offset = None - # Run decoder. - with self._patch_apply_rotary_pos_emb(): - hidden_states = self.decoder( - hidden_states=decoder_input, + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # inference. Skip wrapping if decoder_input is logged after decoder completion. + if (inference_context is not None and not self.training and not has_config_logger_enabled(self.config)): + decoder_input = WrappedTensor(decoder_input) + + return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. + """ + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if mtp_in_postprocess: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=hidden_states, attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, @@ -235,17 +272,28 @@ def forward( rotary_pos_sin=rotary_pos_sin, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, + embedding=self.embedding, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + compute_language_model_loss=self.compute_language_model_loss, **(extra_block_kwargs or {}), - **kwargs, ) if not self.post_process: return hidden_states + if (not self.training and inference_context is not None + and inference_context.materialize_only_last_token_logits): + if inference_context.is_static_batching(): + hidden_states = hidden_states[-1:, :, :] + else: + # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden + # state ([B, H]) → unsqueeze back to [1, B, H] + # (so that the output layer, which expects S×B×H, receives only the final token) + hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() args = get_args() if args.task_type == 'causal_lm': logits, _ = self.output_layer( @@ -271,5 +319,78 @@ def forward( return loss + # Code borrowed from NVIDIA/Megatron-LM + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + decoder_input: torch.Tensor = None, + labels: torch.Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + + Args: + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( + self._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + )) + # Run decoder. + with self._patch_apply_rotary_pos_emb(): + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + **kwargs, + ) + + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + ) + def get_input_tensor(self): return self.decoder.input_tensor From 6738b66561a0ee177c384c0187438404005e57f6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 10 Nov 2025 21:59:15 +0800 Subject: [PATCH 2/8] update --- swift/megatron/model/gpt_model.py | 79 +------------------------------ 1 file changed, 1 insertion(+), 78 deletions(-) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 9bc438413d..0aaa563277 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -288,12 +288,8 @@ def forward( rotary_pos_sin=rotary_pos_sin, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, - embedding=self.embedding, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - compute_language_model_loss=self.compute_language_model_loss, **(extra_block_kwargs or {}), + **kwargs, ) args = get_args() @@ -317,78 +313,5 @@ def forward( inference_context=inference_context, ) - # Code borrowed from NVIDIA/Megatron-LM - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor, - decoder_input: torch.Tensor = None, - labels: torch.Tensor = None, - inference_context: BaseInferenceContext = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - runtime_gather_output: Optional[bool] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units - - Args: - runtime_gather_output (bool): Gather output at runtime. Default None means - `parallel_output` arg in the constructor will be used. - """ - - inference_context = deprecate_inference_params(inference_context, inference_params) - - decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( - self._preprocess( - input_ids=input_ids, - position_ids=position_ids, - decoder_input=decoder_input, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - )) - # Run decoder. - with self._patch_apply_rotary_pos_emb(): - hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - **(extra_block_kwargs or {}), - **kwargs, - ) - - return self._postprocess( - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=self.mtp_process, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, - ) - def get_input_tensor(self): return self.decoder.input_tensor From b460d2847f4ef7464eeaf0fb6189c58a9c755a4c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 18 Nov 2025 16:42:25 +0800 Subject: [PATCH 3/8] update --- .../Megatron-SWIFT/Command-line-parameters.md | 5 +-- .../Megatron-SWIFT/Command-line-parameters.md | 4 +- swift/megatron/model/gpt_bridge.py | 44 ++++++++++++++++++- swift/megatron/trainers/base.py | 4 +- 4 files changed, 50 insertions(+), 7 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 25ddd68fbc..28e6c0eeca 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -219,9 +219,8 @@ - qk_pos_emb_head_dim: QK 投影中位置嵌入的维度。默认为None,自动从config.json读取。 **MTP参数** -- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. -- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. - +- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 +- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 **Tuner参数**: - train_type: 可选为'lora'和'full'。默认为'full'。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 84bc68050b..2a8674b0ee 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -233,8 +233,8 @@ For guidance on selecting parallelization strategies, please refer to the [Train **MTP Parameters** -- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 -- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 +- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. +- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. **Tuner Parameters**: diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 40aff04568..48c5f8e2ab 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -26,6 +26,7 @@ # Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225 class GPTBridge: hf_layers_prefix = 'model.layers' + hf_mtp_prefix = 'model.layers' hf_embed_key = 'model.embed_tokens.weight' hf_final_layernorm_key = 'model.norm.weight' hf_lm_head_key = 'lm_head.weight' @@ -79,7 +80,9 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: # mla 'linear_q_proj', 'linear_q_up_proj', - 'linear_kv_up_proj' + 'linear_kv_up_proj', + # mtp + 'eh_proj', } if self.args.task_type == 'causal_lm': dim0_keys.add('output_layer') @@ -1018,6 +1021,23 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd else: yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} + + if not to_mcore or is_pp_last_stage and self.args.mtp_num_layers: + layer_idx = 0 + while layer_idx < self.args.mtp_num_layers: + mtp_layer = mg_model.mtp.layers[layer_idx] if hasattr(mg_model, 'mtp') else None + if self.hf_mtp_prefix == self.hf_layers_prefix: + hf_layer_idx = layer_idx + self.args.num_layers + else: + hf_layer_idx = layer_idx + res = self._convert_mtp_layer(mtp_layer, hf_state_dict, f'{self.hf_mtp_prefix}.', hf_layer_idx, + to_mcore) + layer_idx += 1 + if to_mcore: + yield + else: + yield from list(self._add_prefix(res, hf_prefix).items()) + hf_state_dict = {} if not to_mcore or is_pp_last_stage: hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: @@ -1026,6 +1046,28 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + def _convert_mtp_layer(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + if not to_mcore: + # TODO: 'embed_tokens.weight', 'shared_head.head.weight' + pass + for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: + self._set_state_dict(mg_layer, key, hf_state_dict, key, to_mcore) + if layer_idx >= len(self.hf_layers): + layer_idx = -1 + hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False, adapter_name: str = 'default'): self._is_peft_format = is_peft_format self._adapter_name = adapter_name diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 999745acbf..4eebeb2eeb 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -785,6 +785,7 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear track_names.append('load_balancing_loss') if args.moe_z_loss_coeff is not None: track_names.append('z_loss') + track_moe_kwargs = {'mtp_num_layers': args.mtp_num_layers} if self.mcore_013 else {} track_moe_metrics( loss_scale=moe_loss_scale, iteration=iteration, @@ -795,7 +796,8 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear force_initialize=True, track_names=track_names, num_layers=args.num_layers, - moe_layer_freq=args.moe_layer_freq) + moe_layer_freq=args.moe_layer_freq, + **track_moe_kwargs) if args.mtp_num_layers is not None: mtp_loss_scale = 1 / get_num_microbatches() MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) From cfa8057adcec52bb825685594b4a8f8d756966fe Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 10:56:27 +0800 Subject: [PATCH 4/8] update --- swift/megatron/model/gpt_bridge.py | 30 +++++++++++++++--------------- swift/megatron/model/gpt_model.py | 1 + swift/megatron/trainers/utils.py | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 48c5f8e2ab..b789268c91 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1023,14 +1023,10 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = {} if not to_mcore or is_pp_last_stage and self.args.mtp_num_layers: + lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model layer_idx = 0 while layer_idx < self.args.mtp_num_layers: - mtp_layer = mg_model.mtp.layers[layer_idx] if hasattr(mg_model, 'mtp') else None - if self.hf_mtp_prefix == self.hf_layers_prefix: - hf_layer_idx = layer_idx + self.args.num_layers - else: - hf_layer_idx = layer_idx - res = self._convert_mtp_layer(mtp_layer, hf_state_dict, f'{self.hf_mtp_prefix}.', hf_layer_idx, + res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) layer_idx += 1 if to_mcore: @@ -1046,21 +1042,25 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) - def _convert_mtp_layer(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): - hf_prefix = f'{hf_prefix}{layer_idx}.' + def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None + if self.hf_mtp_prefix == self.hf_layers_prefix: + hf_layer_idx = layer_idx + self.args.num_layers + else: + hf_layer_idx = layer_idx + hf_prefix = f'{hf_prefix}{hf_layer_idx}.' if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - if not to_mcore: - # TODO: 'embed_tokens.weight', 'shared_head.head.weight' - pass + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) + self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: self._set_state_dict(mg_layer, key, hf_state_dict, key, to_mcore) - if layer_idx >= len(self.hf_layers): - layer_idx = -1 - hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) + if hf_layer_idx >= len(self.hf_layers): + hf_layer_idx = -1 + hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) if to_mcore: hf_state_dict = {} diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index b529a73337..19d14a7b47 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -305,6 +305,7 @@ def forward( args = get_args() labels = labels if args.task_type == 'causal_lm' else None if mcore_013: + # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 return self._postprocess( hidden_states=hidden_states, input_ids=input_ids, diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 594561cdd8..88586edb1e 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -57,7 +57,7 @@ def get_batch_on_this_tp_rank(data, vp_stage=None): else: is_pp_first_stage = mpu.is_pipeline_first_stage() is_pp_last_stage = mpu.is_pipeline_last_stage() - if not is_pp_first_stage: + if not args.mtp_num_layers and not is_pp_first_stage: batch['input_ids'] = None if not is_pp_last_stage: batch['labels'] = None From 2bfeed50bfb3495916ca3f3b8f6075c1ec8d6c9f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 11:05:14 +0800 Subject: [PATCH 5/8] update --- swift/megatron/model/gpt_bridge.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 63a59430e7..63bd030acb 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1042,8 +1042,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model layer_idx = 0 while layer_idx < self.args.mtp_num_layers: - res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, - to_mcore) + res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) layer_idx += 1 if to_mcore: yield @@ -1069,15 +1068,16 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', + to_mcore) self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: - self._set_state_dict(mg_layer, key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) if hf_layer_idx >= len(self.hf_layers): hf_layer_idx = -1 - hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + hf_state_dict.update(self._set_layer_attn(mtp_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mtp_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) if to_mcore: hf_state_dict = {} else: From 4d60be4b8aea5e29e61b5d78fa3f410ca6afe1be Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 15:15:56 +0800 Subject: [PATCH 6/8] update --- swift/llm/argument/infer_args.py | 10 ++ swift/llm/infer/infer_engine/infer_engine.py | 3 +- swift/llm/infer/infer_engine/sglang_engine.py | 9 ++ swift/megatron/init.py | 97 +++++++++++++++++++ 4 files changed, 118 insertions(+), 1 deletion(-) diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index cbf746d4dc..e0cce69ffb 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -61,6 +61,12 @@ class SglangArguments: sglang_kv_cache_dtype: str = 'auto' sglang_enable_dp_attention: bool = False sglang_disable_custom_all_reduce: bool = True + # speculative decoding + # e.g. EAGLE, EAGLE3, NEXTN + sglang_speculative_algorithm: Optional[str] = None + sglang_speculative_num_steps: Optional[int] = None + sglang_speculative_eagle_topk: Optional[int] = None + sglang_speculative_num_draft_tokens: Optional[int] = None def get_sglang_engine_kwargs(self): kwargs = { @@ -76,6 +82,10 @@ def get_sglang_engine_kwargs(self): 'kv_cache_dtype': self.sglang_kv_cache_dtype, 'enable_dp_attention': self.sglang_enable_dp_attention, 'disable_custom_all_reduce': self.sglang_disable_custom_all_reduce, + 'speculative_algorithm': self.sglang_speculative_algorithm, + 'speculative_num_steps': self.sglang_speculative_num_steps, + 'speculative_eagle_topk': self.sglang_speculative_eagle_topk, + 'speculative_num_draft_tokens': self.sglang_speculative_num_draft_tokens, } if self.task_type == 'embedding': kwargs['task_type'] = 'embedding' diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index 4e1c903d17..86c9583e40 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -32,6 +32,7 @@ def _post_init(self, template=None): self.max_model_len = self.model_info.max_model_len self.task_type = self.model_info.task_type self.config = self.model_info.config + self.max_tokens_offset = 0 if template is None: ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None)) logger.info('Create the default_template for the infer_engine') @@ -220,7 +221,7 @@ def set_default_max_tokens(self, request_config: RequestConfig, inputs: Dict[str max_model_len = 8192 logger.warning( 'The current model is unable to retrieve `max_model_len`. It is set to the default value of 8192.') - max_max_tokens = max_model_len - num_tokens + max_max_tokens = max_model_len - num_tokens + self.max_tokens_offset if max_tokens is None: request_config.max_tokens = max_max_tokens elif max_max_tokens < request_config.max_tokens: diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index dd78d0b651..13049d0ada 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -48,6 +48,10 @@ def __init__( kv_cache_dtype: str = 'auto', enable_dp_attention: bool = False, disable_custom_all_reduce: bool = True, + speculative_algorithm: Optional[str] = None, + speculative_num_steps: Optional[int] = None, + speculative_eagle_topk: Optional[int] = None, + speculative_num_draft_tokens: Optional[int] = None, log_level='error', engine_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, @@ -88,6 +92,10 @@ def __init__( kv_cache_dtype=kv_cache_dtype, enable_dp_attention=enable_dp_attention, disable_custom_all_reduce=disable_custom_all_reduce, + speculative_algorithm=speculative_algorithm, + speculative_num_steps=speculative_num_steps, + speculative_eagle_topk=speculative_eagle_topk, + speculative_num_draft_tokens=speculative_num_draft_tokens, log_level=log_level, skip_tokenizer_init=True, trust_remote_code=True, @@ -98,6 +106,7 @@ def __init__( self.server_args.is_embedding = True self.engine = sgl.Engine(server_args=self.server_args) self._load_generation_config() + self.max_tokens_offset = -speculative_num_draft_tokens def _load_generation_config(self) -> None: generation_config_path = os.path.join(self.model_dir, 'generation_config.json') diff --git a/swift/megatron/init.py b/swift/megatron/init.py index d2b7b7cb95..588f3066ab 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -388,6 +388,102 @@ def build_tokenizer(args): global_vars.build_tokenizer = build_tokenizer +def _patch_mtp(): + from megatron.core import InferenceParams + from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer + from megatron.core.packed_seq_params import PackedSeqParams + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor = None, + context_mask: torch.Tensor = None, + rotary_pos_emb: torch.Tensor = None, + rotary_pos_cos: torch.Tensor = None, + rotary_pos_sin: torch.Tensor = None, + attention_bias: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + sequence_len_offset: torch.Tensor = None, + embedding=None, + ): + """ + Execute the forward pass through the Multi-Token Prediction (MTP) layer. + + Args: + input_ids (Tensor): Input token IDs . + position_ids (Tensor): Positional IDs of the input tokens. + hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the + sequence length, b is the batch size, and h is the hidden size. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention, if applicable. + context_mask (Tensor, optional): Mask for cross-attention context, if applicable. + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Tensor, optional): Cosine component of rotary positional embeddings. + rotary_pos_sin (Tensor, optional): Sine component of rotary positional embeddings. + sequence_len_offset (Tensor, optional): Offset for sequence length, if applicable. + embedding (Callable): The embedding module from gpt model to compute the decoder input. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + # TODO: Multimodal compatible; MTP initialization + # TODO: packed_seq_params offset + assert context is None, f'multi token prediction + cross attention is not yet supported.' + input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, + ) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + apply_rope_fusion = self.config.apply_rope_fusion + self.config.apply_rope_fusion = False + if packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + self._proj_and_transformer_layer, + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + else: + hidden_states = self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + self.config.apply_rope_fusion = apply_rope_fusion + return hidden_states, input_ids, position_ids + + MultiTokenPredictionLayer.forward = forward + + def _patch_peft_ModulesToSaveWrapper(): if version.parse(peft.__version__) >= version.parse('0.16'): from peft.utils import other as peft_module @@ -686,6 +782,7 @@ def _patch_megatron(): _patch_build_train_valid_test_datasets() _patch_mrope() _patch_megatron_tokenizer() + _patch_mtp() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: From 8d4fe50cd4ec56f3cbd7e6ca98490797b18ab049 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 17:17:16 +0800 Subject: [PATCH 7/8] update --- swift/megatron/init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 588f3066ab..d2797628b2 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -434,7 +434,7 @@ def forward( """ # TODO: Multimodal compatible; MTP initialization # TODO: packed_seq_params offset - assert context is None, f'multi token prediction + cross attention is not yet supported.' + assert context is None, 'multi token prediction + cross attention is not yet supported.' input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( input_ids=input_ids, position_ids=position_ids, From ed5646c5e767c62a553bb6ea7cb2e632c9d3800c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 17:29:04 +0800 Subject: [PATCH 8/8] update --- docs/source/Instruction/Supported-models-and-datasets.md | 4 +++- docs/source_en/Instruction/Supported-models-and-datasets.md | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md index 1b4e35058f..70cb341543 100644 --- a/docs/source/Instruction/Supported-models-and-datasets.md +++ b/docs/source/Instruction/Supported-models-and-datasets.md @@ -1133,6 +1133,7 @@ |-|default|huge dataset|-|pretrain, quality|[allenai/c4](https://huggingface.co/datasets/allenai/c4)| |[bespokelabs/Bespoke-Stratos-17k](https://modelscope.cn/datasets/bespokelabs/Bespoke-Stratos-17k)|default|16710|480.7±236.1, min=266, max=3556|chat, sft, cot, r1|[bespokelabs/Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)| |-|default|huge dataset|-|pretrain, quality|[cerebras/SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B)| +|[clip-benchmark/wds_voc2007_multilabel](https://modelscope.cn/datasets/clip-benchmark/wds_voc2007_multilabel)|default|2501|112.0±0.0, min=112, max=112|multilabel, multi-modal|[clip-benchmark/wds_voc2007_multilabel](https://huggingface.co/datasets/clip-benchmark/wds_voc2007_multilabel)| |[codefuse-ai/CodeExercise-Python-27k](https://modelscope.cn/datasets/codefuse-ai/CodeExercise-Python-27k)|default|27224|337.3±154.2, min=90, max=2826|chat, coding, 🔥|-| |[codefuse-ai/Evol-instruction-66k](https://modelscope.cn/datasets/codefuse-ai/Evol-instruction-66k)|default|66862|440.1±208.4, min=46, max=2661|chat, coding, 🔥|-| |[damo/MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench)|default
mini|638149|859.2±460.1, min=38, max=3479|chat, agent, multi-round|-| @@ -1160,6 +1161,7 @@ |[modelscope/clue](https://modelscope.cn/datasets/modelscope/clue)|cmnli|391783|81.6±16.0, min=54, max=157|text-generation, classification|[clue](https://huggingface.co/datasets/clue)| |[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption)|train
validation|454617|389.6±68.4, min=70, max=587|chat, multi-modal, vision, 🔥|-| |[modelscope/gsm8k](https://modelscope.cn/datasets/modelscope/gsm8k)|main|7473|88.6±21.6, min=41, max=241|qa, math|-| +|[open-r1/DAPO-Math-17k-Processed](https://modelscope.cn/datasets/open-r1/DAPO-Math-17k-Processed)|all|17398|122.3±65.2, min=41, max=1517|math, rlvr|[open-r1/DAPO-Math-17k-Processed](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed)| |[open-r1/verifiable-coding-problems-python](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python)|default|35735|559.0±255.2, min=74, max=6191|grpo, code|[open-r1/verifiable-coding-problems-python](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python)| |[open-r1/verifiable-coding-problems-python-10k](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k)|default|1800|581.6±233.4, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k)| |[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)|default|1574|575.7±234.3, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)| @@ -1189,7 +1191,7 @@ |[swift/RedPajama-Data-V2](https://modelscope.cn/datasets/swift/RedPajama-Data-V2)|default|huge dataset|-|pretrain, quality|[togethercomputer/RedPajama-Data-V2](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)| |[swift/ScienceQA](https://modelscope.cn/datasets/swift/ScienceQA)|default|16967|101.7±55.8, min=32, max=620|multi-modal, science, vqa, quality|[derek-thomas/ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA)| |[swift/SlimOrca](https://modelscope.cn/datasets/swift/SlimOrca)|default|517982|405.5±442.1, min=47, max=8312|quality, en|[Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)| -|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| +|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb
rerank|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| |[swift/ToolBench](https://modelscope.cn/datasets/swift/ToolBench)|default|124345|2251.7±1039.8, min=641, max=9451|chat, agent, multi-round|-| |[swift/VQAv2](https://modelscope.cn/datasets/swift/VQAv2)|default|huge dataset|-|en, vqa, quality|[HuggingFaceM4/VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)| |[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT)|Generic
Temporal
Consistency|3206|87.4±48.3, min=31, max=398|chat, multi-modal, video, 🔥|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 808cec5a85..6d5059dd93 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -1134,6 +1134,7 @@ The table below introduces information about the datasets integrated with ms-swi |-|default|huge dataset|-|pretrain, quality|[allenai/c4](https://huggingface.co/datasets/allenai/c4)| |[bespokelabs/Bespoke-Stratos-17k](https://modelscope.cn/datasets/bespokelabs/Bespoke-Stratos-17k)|default|16710|480.7±236.1, min=266, max=3556|chat, sft, cot, r1|[bespokelabs/Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)| |-|default|huge dataset|-|pretrain, quality|[cerebras/SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B)| +|[clip-benchmark/wds_voc2007_multilabel](https://modelscope.cn/datasets/clip-benchmark/wds_voc2007_multilabel)|default|2501|112.0±0.0, min=112, max=112|multilabel, multi-modal|[clip-benchmark/wds_voc2007_multilabel](https://huggingface.co/datasets/clip-benchmark/wds_voc2007_multilabel)| |[codefuse-ai/CodeExercise-Python-27k](https://modelscope.cn/datasets/codefuse-ai/CodeExercise-Python-27k)|default|27224|337.3±154.2, min=90, max=2826|chat, coding, 🔥|-| |[codefuse-ai/Evol-instruction-66k](https://modelscope.cn/datasets/codefuse-ai/Evol-instruction-66k)|default|66862|440.1±208.4, min=46, max=2661|chat, coding, 🔥|-| |[damo/MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench)|default
mini|638149|859.2±460.1, min=38, max=3479|chat, agent, multi-round|-| @@ -1161,6 +1162,7 @@ The table below introduces information about the datasets integrated with ms-swi |[modelscope/clue](https://modelscope.cn/datasets/modelscope/clue)|cmnli|391783|81.6±16.0, min=54, max=157|text-generation, classification|[clue](https://huggingface.co/datasets/clue)| |[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption)|train
validation|454617|389.6±68.4, min=70, max=587|chat, multi-modal, vision, 🔥|-| |[modelscope/gsm8k](https://modelscope.cn/datasets/modelscope/gsm8k)|main|7473|88.6±21.6, min=41, max=241|qa, math|-| +|[open-r1/DAPO-Math-17k-Processed](https://modelscope.cn/datasets/open-r1/DAPO-Math-17k-Processed)|all|17398|122.3±65.2, min=41, max=1517|math, rlvr|[open-r1/DAPO-Math-17k-Processed](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed)| |[open-r1/verifiable-coding-problems-python](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python)|default|35735|559.0±255.2, min=74, max=6191|grpo, code|[open-r1/verifiable-coding-problems-python](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python)| |[open-r1/verifiable-coding-problems-python-10k](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k)|default|1800|581.6±233.4, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k)| |[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)|default|1574|575.7±234.3, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)| @@ -1190,7 +1192,7 @@ The table below introduces information about the datasets integrated with ms-swi |[swift/RedPajama-Data-V2](https://modelscope.cn/datasets/swift/RedPajama-Data-V2)|default|huge dataset|-|pretrain, quality|[togethercomputer/RedPajama-Data-V2](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)| |[swift/ScienceQA](https://modelscope.cn/datasets/swift/ScienceQA)|default|16967|101.7±55.8, min=32, max=620|multi-modal, science, vqa, quality|[derek-thomas/ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA)| |[swift/SlimOrca](https://modelscope.cn/datasets/swift/SlimOrca)|default|517982|405.5±442.1, min=47, max=8312|quality, en|[Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)| -|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| +|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb
rerank|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| |[swift/ToolBench](https://modelscope.cn/datasets/swift/ToolBench)|default|124345|2251.7±1039.8, min=641, max=9451|chat, agent, multi-round|-| |[swift/VQAv2](https://modelscope.cn/datasets/swift/VQAv2)|default|huge dataset|-|en, vqa, quality|[HuggingFaceM4/VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)| |[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT)|Generic
Temporal
Consistency|3206|87.4±48.3, min=31, max=398|chat, multi-modal, video, 🔥|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)|