diff --git a/README.md b/README.md
index b263879aec..455ff142d3 100644
--- a/README.md
+++ b/README.md
@@ -65,7 +65,7 @@ You can contact us and communicate with us by adding our group:
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- 🍊 **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, GKD, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
-- 🥥 **Megatron Parallelism**: Supports accelerating CPT/SFT/DPO/KTO/RM using Megatron parallelism techniques, currently compatible with 200+ pure text large models, 100+ multi-modal large models.
+- 🥥 **Megatron Parallelism**: Supports accelerating CPT/SFT/GRPO/DPO/KTO/RM using Megatron parallelism techniques, currently compatible with 200+ pure text large models, 100+ multi-modal large models.
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
@@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group:
## 🎉 News
+- 🎁 2025.11.14: Megatron GRPO is now available! Check out the [docs](docs/source_en/Megatron-SWIFT/GRPO.md) and [examples](examples/megatron/grpo).
- 🎁 2025.11.04: Support for [Mcore-Bridge](docs/source_en/Megatron-SWIFT/Mcore-Bridge.md), making Megatron training as simple and easy to use as transformers.
- 🎁 2025.10.28: Ray [here](docs/source_en/Instruction/Ray.md).
- 🎁 2025.10.28: Support [use yaml](examples/yaml) to configure command line parameters.
diff --git a/README_CN.md b/README_CN.md
index da2b914169..08a7f1b93d 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -62,7 +62,7 @@
- **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
- 🍊 **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
- 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
-- 🥥 **Megatron并行技术**:支持使用Megatron并行技术对CPT/SFT/DPO/KTO/RM进行加速,现支持200+纯文本大模型和100+多模态大模型。
+- 🥥 **Megatron并行技术**:支持使用Megatron并行技术对CPT/SFT/GRPO/DPO/KTO/RM进行加速,现支持200+纯文本大模型和100+多模态大模型。
- **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
- **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
- 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
@@ -71,6 +71,7 @@
- **模型量化**:支持AWQ、GPTQ、FP8和BNB的量化导出,导出的模型支持使用vLLM/SGLang/LmDeploy推理加速,并支持继续训练。
## 🎉 新闻
+- 🎁 2025.11.14: Megatron GRPO现已支持!查看[文档](docs/source/Megatron-SWIFT/GRPO.md)和[示例](examples/megatron/grpo)。
- 🎁 2025.11.04: 支持[Mcore-Bridge](docs/source/Megatron-SWIFT/Mcore-Bridge.md),使Megatron训练像transformers一样简单易用。
- 🎁 2025.10.28: Ray [已支持](docs/source/Instruction/Ray.md)。
- 🎁 2025.10.28: 已支持[使用yaml](examples/yaml)配置命令行参数。
diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md
index 53e4d349c8..2a2f9db2d4 100644
--- a/docs/source/Instruction/Command-line-parameters.md
+++ b/docs/source/Instruction/Command-line-parameters.md
@@ -566,13 +566,13 @@ reward模型参数将在PPO、GRPO中使用。
- use_vllm: 是否使用 vLLM 作为 GRPO 生成的 infer_backend,默认为False。
- vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时,
- vllm_mode server 参数
+ - vllm_server_host: vLLM server host地址,默认为None。
+ - vllm_server_port: vLLM server 服务端口,默认为8000。
- vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。
- - vllm_server_host:vLLM server host地址,默认为None。
- - vllm_server_port vLLM server 服务端口,默认为8000。
- - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。
+ - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。
- vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。
- async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`.
- - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。
+ - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: 环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。
- vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。)
- vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。
- vllm_max_model_len: vllm透传参数,默认为None。
@@ -581,7 +581,7 @@ reward模型参数将在PPO、GRPO中使用。
- vllm_enable_prefix_caching: vllm透传参数,默认为True。
- vllm_tensor_parallel_size: tp并行数,默认为`1`。
- vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。
- - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放。
+ - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1, 2], 默认为0,不释放。
- offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。
- offload_model: 是否在vLLM推理时 offload 模型,默认为False。
- completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。
@@ -593,7 +593,7 @@ reward模型参数将在PPO、GRPO中使用。
- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。
- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。
- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。
-- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。
+- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。
- advantage_estimator: 优势计算函数,默认为 `grpo`,即计算组内相对优势,可选项为 `grpo`、[`rloo`](./GRPO/AdvancedResearch/RLOO.md)、[`reinforce_plus_plus`](./GRPO/AdvancedResearch/REINFORCEPP.md)。
- kl_in_reward: 控制 KL 散度正则项的处理位置;`false`表示作为损失函数的独立正则项,`true`表示将 KL 直接并入奖励(从奖励中扣除)。默认情况与advantage_estimator绑定,`grpo`下默认为`false`,`rloo` 和 `reinforce_plus_plus` 下默认为 `true`。
- scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。
@@ -606,6 +606,8 @@ reward模型参数将在PPO、GRPO中使用。
- top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md)
- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics)
+##### 奖励函数参数
+内置的奖励函数请参考[文档](./GRPO/DeveloperGuide/reward_function.md)。
cosine 奖励参数
- cosine_min_len_value_wrong:cosine 奖励函数参数,生成错误答案时,最小长度对应的奖励值。默认值为-0.5。
- cosine_max_len_value_wrong:生成错误答案时,最大长度对应的奖励值。默认值为0.0。
diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md
index 1f21f2abfe..9bc9df2f80 100644
--- a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md
+++ b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md
@@ -2,7 +2,7 @@
**版本依赖**:ms-swift>=3.7
-[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071)中指出GRPO在计算重要性采样权重时,是在token级别进行操作的。然而,这种做法由于每个token仅采样一次,无法实现有效的分布校正,反而会在模型训练过程中引入高方差噪声,极易导致模型的梯度估计不稳定,最终造成模型训练的崩塌。因此,论文认为,优化目标的单位应该与奖励的单位保持一致。由于奖励通常是在序列级别(即完整生成的回复)给出的,因此更合理的做法是将 off-policy 校正和优化也提升到序列级别,而非 token 级别。以下是三种计算策略对比:
+[Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071)中指出GRPO在计算重要性采样权重时,是在token级别进行操作的。然而,这种做法由于每个token仅采样一次,无法实现有效的分布校正,反而会在模型训练过程中引入高方差噪声,极易导致模型的梯度估计不稳定,最终造成模型训练的崩塌。因此,论文认为,优化目标的单位应该与奖励的单位保持一致。由于奖励通常是在序列级别(即完整生成的回复)给出的,因此更合理的做法是将 off-policy 校正和优化也提升到序列级别,而非 token 级别。以下是三种计算策略对比:
1. GRPO
对每个 token 独立计算重要性采样比,具体公式为
@@ -54,7 +54,7 @@ importance_weights = torch.exp(log_importance_weights)
- `importance_sampling_level sequence` (GSPO)
- `importance_sampling_level sequence_token` (GSPO-token)
-其中 sequence_token 要求 ms-swift > 3.7 (源码安装)
+其中 sequence_token 要求 ms-swift >= 3.8
论文其他超参
```bash
diff --git a/docs/source/Instruction/Use-tuners.md b/docs/source/Instruction/Use-tuners.md
index c84ca6fe0c..7461877fc8 100644
--- a/docs/source/Instruction/Use-tuners.md
+++ b/docs/source/Instruction/Use-tuners.md
@@ -15,7 +15,7 @@ tuner是指附加在模型上的额外结构部分,用于减少训练参数量
- Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
- Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
- Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
-- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) >
+- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) >
- [PEFT](https://github.com/huggingface/peft)提供的tuners, 如AdaLoRA、DoRA、Fourierft等
## 接口列表
diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md
index f551b9d906..5c75aa28c0 100644
--- a/docs/source/Megatron-SWIFT/Command-line-parameters.md
+++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md
@@ -246,36 +246,13 @@ lora训练:
- lora_bias: 默认为`'none'`,可以选择的值: 'none'、'all'。如果你要将bias全都设置为可训练,你可以设置为`'all'`。
- use_rslora: 默认为`False`,是否使用`RS-LoRA`。
-
-**DPO参数**:
-- ref_load: ref_model的加载路径。采用DPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。
-- ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。
-- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。
-- 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。
- - **注意**:在"ms-swift<3.8",其默认值为`1.`。在"ms-swift>=3.8"该默认值修改为`None`。
-- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。
-- label_smoothing: 默认为0.。
-- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。
-- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。
-
-**KTO参数**:
-- ref_load: 含义同DPO。
-- ref_adapter_load: 含义同DPO。
-- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。
-- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。
-- desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。
-- undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。
-
-**RM参数**:
-- center_rewards_coefficient: 用于激励奖励模型输出均值为零的奖励的系数,具体查看这篇[论文](https://huggingface.co/papers/2312.09244)。推荐值:0.01。
-
**Mcore-Bridge参数**
- 🔥load_safetensors: 默认为False,是否直接从safetensors加载权重。
- 🔥save_safetensors: 默认为False,是否直接保存成safetensors权重。注意,若该参数设置为True,则不会存储优化器权重、随机数状态等断点续训内容。
- model: safetensors权重的model_id或者model_path。默认为None。
- model_type: 模型类型。介绍参考[ms-swift命令行参数文档](../Instruction/Command-line-parameters.md)。
- adapters: safetensors格式的LoRA增量权重的adapter_id或者adapter_path。默认为`[]`。
-- ref_model: ref_model safetensors权重的model_id或者model_path。采用dpo、kto算法且使用全参数训练时需要传入。默认为None,设置为`--model`。
+- ref_model: ref_model safetensors权重的model_id或者model_path。采用grpo、dpo、kto算法且使用全参数训练时需要传入。默认为None,即设置为`--model`。
- ref_adapters: ref_adapters safetensors权重的adapter_id或者adapter_path的列表(目前只支持长度为1),默认为`[]`。
- use_hf: 控制模型下载、数据集下载、模型推送使用ModelScope还是HuggingFace。默认为False,使用ModelScope。
- hub_token: hub token. modelscope的hub token可以查看[这里](https://modelscope.cn/my/myaccesstoken)。默认为None。
@@ -318,11 +295,79 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用
## RLHF参数
除了继承训练参数外,还支持以下参数:
-- 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo'、'kto'和'rm'。
+- 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo'、'grpo'、'kto'和'rm'。
- loss_scale: 覆盖[基本参数](../Instruction/Command-line-parameters.md)中的loss_scale。默认为'last_round'。
- calculate_per_token_loss: 覆盖Megatron参数,默认为False。
+### DPO参数
+- ref_load: ref_model的加载路径。采用DPO/GRPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。
+- ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。
+- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。
+- 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。
+ - **注意**:在"ms-swift<3.8",其默认值为`1.`。在"ms-swift>=3.8"该默认值修改为`None`。
+- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。
+- label_smoothing: 默认为0.。
+- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。
+- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。
+
+### KTO参数
+- ref_load: 含义同DPO。
+- ref_adapter_load: 含义同DPO。
+- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。
+- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。
+- desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。
+- undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。
+
+### RM参数
+- center_rewards_coefficient: 用于激励奖励模型输出均值为零的奖励的系数,具体查看这篇[论文](https://huggingface.co/papers/2312.09244)。推荐值:0.01。
+
+### GRPO参数
+- ref_load: 含义同DPO。
+- ref_adapter_load: 含义同DPO。
+- beta: KL正则系数,默认为0.04,设置为0时不加载ref model。
+- micro_batch_size: 每个device的批次大小,默认为1。
+- global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。
+- steps_per_generation:每轮生成的优化步数,即采样批量大小相对global_batch_size的倍数,默认为1。
+- generation_batch_size: 采样批量大小,需要是global_batch_size的倍数,默认等于global_batch_size*steps_per_generation。
+- num_generations: 每个prompt采样的数量,论文中的G值,默认为8。
+- reward_funcs: GRPO算法奖励函数,可选项为`accuracy`、`format`、`cosine`、`repetition`和`soft_overlong`,见swift/plugin/orm.py。你也可以在plugin中自定义自己的奖励函数。默认为`[]`。
+- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。默认为 None,即所有奖励的权重都相等,为`1.0`。
+ - 提示:如果GRPO训练中包含`--reward_model`,则其加在奖励函数的最后位置。
+- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)。
+- log_completions: 是否记录训练中的模型生成内容,默认为False。
+- vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。
+- vllm_mode server 参数
+ - vllm_server_host: vLLM server host地址,默认为None。
+ - vllm_server_port: vLLM server 服务端口,默认为8000。
+ - vllm_server_base_url: vLLM server的Base URL(比如 http://localhost:8000),默认为None。设置后,忽略host和port设置。
+ - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。
+ - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。
+ - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认为`false`。
+ - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: 环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。
+- vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。)
+ - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。
+ - vllm_max_model_len: vllm透传参数,默认为None。
+ - vllm_enforce_eager: vllm透传参数,默认为False。
+ - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。
+ - vllm_enable_prefix_caching: vllm透传参数,默认为True。
+ - vllm_tensor_parallel_size: tp并行数,默认为`1`。
+ - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](../Instruction/GRPO/GetStarted/GRPO.md#权重同步加速)。
+ - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1, 2], 默认为0,不释放。
+ - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。
+ - offload_model: 是否在vLLM推理时 offload 模型,默认为False。
+- num_iterations: 每条数据的更新次数,[GRPO论文](https://arxiv.org/abs/2402.03300)中的 $\mu$ 值,默认为1。
+- epsilon: clip 系数,默认为0.2。
+- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。
+- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。
+- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。
+- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。
+- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。
+- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。
+- scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。
+
+内置奖励函数参数请参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数)。
+
## 导出参数
这里介绍`megatron export`的参数(需"ms-swift>=3.10"),若要使用`swift export`导出命令,请参考[ms-swift命令行参数文档](../Instruction/Command-line-parameters.md#导出参数)。`megatron export`相比`swift export`,支持分布式和多机导出。Megatron导出参数继承自Megatron参数和基本参数。
- 🔥to_mcore: HF格式权重转成Megatron格式。默认为False。
diff --git a/docs/source/Megatron-SWIFT/GRPO.md b/docs/source/Megatron-SWIFT/GRPO.md
new file mode 100644
index 0000000000..a8aa4df0e4
--- /dev/null
+++ b/docs/source/Megatron-SWIFT/GRPO.md
@@ -0,0 +1,61 @@
+# GRPO
+
+**版本依赖**:ms-swift >= 3.11
+
+如果你是首次使用 GRPO,请先参考 [GRPO文档](../Instruction/GRPO/GetStarted/GRPO.md)。
+
+Megatron GRPO 当前已支持以下功能:
+
+- **训练模式**:全参数训练与 LoRA 微调
+- **并行策略**:支持上下文并行(CP)、流水线并行(PP)、张量并行(TP)和专家并行(EP)
+- **推理加速**:支持 vLLM 的 colocate 模式和 server 模式
+- **模型支持**:兼容 Megatron-SWIFT 中的 LLM 及 MLLM(多模态大模型)
+- **算法支持**:涵盖 ms-swift GRPO 的大部分功能
+
+以下参数或功能将在后续版本中逐步支持:
+
+- **Entropy 相关配置**:如 `top_entropy_quantile`、`log_entropy`
+- **Reward Model / Reward Model Plugin**
+- **多轮 Rollout 调度机制**(`multi_turn_scheduler`):实现多轮对话策略优化
+- **优势估计器**(`advantage_estimator`):支持更复杂的策略梯度估计方法
+- **KL 散度计入奖励**(`kl_in_reward`)
+- **虚拟流水线并行**(VPP)
+- **参考模型同步更新**(`sync_ref_model`)
+- **Async Generate** (`async_generate`)
+- **num_iterations**
+- **日志同步 SwanLab**
+
+⚠️ 注意:以下参数在 Megatron GRPO 中不生效:
+
+- **`use_vllm`**:Megatron GRPO 暂不支持使用 PTEngine 进行 Rollout 推理。
+- **`move_model_batches`**:该参数专用于 DeepSpeed ZeRO-3 优化,在 Megatron 架构下无效。
+
+与 ms-swift GRPO 相同,Megatron GRPO batch size 相关的参数均以 **completion-level** 为单位,即表示模型生成的 completion 数量,而非 prompt 数量。
+
+## 参数对比
+
+下表对比了 ms-swift 和 Megatron-SWIFT 中批量相关参数的对应关系:
+
+| ms-swift 参数 | Megatron-SWIFT 参数 | 说明 |
+|---------------|---------------------|------|
+| `per_device_train_batch_size` | `micro_batch_size` | 每张 GPU 的训练批次大小(completion-level) |
+| `gradient_accumulation_steps` | - | 梯度累积步数,在 Megatron-SWIFT 中已包含在 `global_batch_size` 的计算中 |
+| - | `global_batch_size` | 全局批次大小(completion-level)<br>**Megatron-SWIFT**: `micro_batch_size × dp_size × gradient_accumulation_steps`<br>**ms-swift**: `per_device_train_batch_size × world_size × gradient_accumulation_steps` |
+| `num_generations` | `num_generations` | 每个 prompt 生成的 completion 数量 |
+| `steps_per_generation` | `steps_per_generation` | Rollout 批次大小相对于训练批次大小的倍数<br>**注意**:在 ms-swift 中需为 `gradient_accumulation_steps` 的整数倍 |
+| `generation_batch_size` | `generation_batch_size` | Rollout 阶段的批次大小(completion-level),需为 `global_batch_size` 的整数倍 |
+
+以下公式用于计算 Megatron GRPO 中的批量:
+
+- **数据并行大小**:`dp_size = world_size / (TP × PP × CP)`
+- **全局批次大小**:`global_batch_size = micro_batch_size × dp_size × gradient_accumulation_steps`
+- **生成批次大小**:`generation_batch_size = global_batch_size × steps_per_generation`
+- **Rollout Prompt 数量**:`num_rollout_prompts = generation_batch_size / num_generations`
+- **训练 Prompt 数量**:`num_train_prompts = global_batch_size / num_generations`
+- **每个 DP group 的训练 Prompt 数量**:`num_prompts_per_dp_group = global_batch_size / num_generations / dp_size`
+
+注意:在 Megatron GRPO 中,每个 DP group 的训练 Prompt 数量须满足 `num_prompts_per_dp_group` 是 `micro_batch_size`的整数倍,以确保训练批次能够正确分配。
+
+更多参数请参考[命令行文档](./Command-line-parameters.md#grpo参数)
+
+训练脚本请参考[Megatron GRPO 脚本](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo)
diff --git a/docs/source/Megatron-SWIFT/Multimodal-Model.md b/docs/source/Megatron-SWIFT/Multimodal-Model.md
index 9cc51732f7..8f51213211 100644
--- a/docs/source/Megatron-SWIFT/Multimodal-Model.md
+++ b/docs/source/Megatron-SWIFT/Multimodal-Model.md
@@ -1,6 +1,6 @@
# 多模态模型
-ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO/KTO/RM。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/Supported-models-and-datasets.md)。
+ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/GRPO/DPO/KTO/RM。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/Supported-models-and-datasets.md)。
环境准备请参考Megatron-SWIFT的[快速开始文档](./Quick-start.md)。
diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md
index 8c92e2b6b9..faff26ecec 100644
--- a/docs/source/Megatron-SWIFT/Quick-start.md
+++ b/docs/source/Megatron-SWIFT/Quick-start.md
@@ -8,6 +8,7 @@ ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数
| ------ | ------ | ---- | ----- | ----- |
| 预训练| ✅ | ✅| ✅ | ✅ |
| 指令监督微调 | ✅ | ✅| ✅ | ✅ |
+| GRPO | ✅ | ✅| ✅ | ✅ |
| DPO | ✅ | ✅| ✅ | ✅ |
| KTO | ✅ | ✅| ✅ | ✅ |
| RM | ✅ | ✅| ✅ | ✅ |
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c5a5fc08c8..f70a8a05c9 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,7 @@ Swift DOCUMENTATION
Megatron-SWIFT/LoRA-Training.md
Megatron-SWIFT/Multimodal-Model.md
Megatron-SWIFT/Mcore-Bridge.md
+ Megatron-SWIFT/GRPO.md
.. toctree::
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 74634419e3..c7dd3d865a 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -577,9 +577,9 @@ The meanings of the following parameters can be referenced [here](https://huggin
- use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
- vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate`
- vllm_mode server parameter
- - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None.
- vllm_server_host: The host address of the vLLM server. Default is None.
- vllm_server_port: The service port of the vLLM server. Default is 8000.
+ - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` and `vllm_server_port` are ignored. Default is None.
- vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.
- vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training.
- async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`.
@@ -592,7 +592,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
- vllm_enable_prefix_caching: A pass-through parameter for vLLM, default is True.
- vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1.
- vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details.
- - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep
+ - sleep_level: make vLLM sleep when the model is training. Options are 0, 1, or 2; default is 0 (no sleep).
- offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`.
- offload_model: Whether to offload the model during inference with vLLM. The default is `False`.
- completion_length_limit_scope: Specifies the scope of the `max_completion_length` limit in multi-turn conversations.
@@ -607,7 +607,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False.
The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions).
- delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
-- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`.
+- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`.
- advantage_estimator: Advantage estimator. Default is `grpo` (group-relative advantage). Options: `grpo`, [`rloo`](./GRPO/AdvancedResearch/RLOO.md), [`reinforce_plus_plus`](./GRPO/AdvancedResearch/REINFORCEPP.md).
- kl_in_reward: Controls where the KL regularization is applied. `false`: KL is a separate loss term. `true`: KL is subtracted from the reward. The default is bound to `advantage_estimator`: `false` for `grpo`, and `true` for `rloo` and `reinforce_plus_plus`.
- scale_rewards: Specifies the reward scaling strategy. Options: `group` (scale by intra-group std), `batch` (scale by batch-wide std), `none` (no scaling). In ms-swift < 3.10, this was a boolean where `true` corresponds to `group` and `false` to `none`. The default is bound to `advantage_estimator`: `group` for `grpo`, `none` for `rloo`, and `batch` for `reinforce_plus_plus`.
@@ -621,6 +621,8 @@ The hyperparameters for the reward function can be found in the [Built-in Reward
- log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics).
+##### Reward function parameters
+Refer to the [documentation](./GRPO/DeveloperGuide/reward_function.md) for built-in reward functions.
cosine reward function arguments
- cosine_min_len_value_wrong (default: -0.5): Reward value corresponding to the minimum length when the answer is incorrect.
diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md
index 2f8ec7ae54..03c67b3c6e 100644
--- a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md
+++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md
@@ -1,8 +1,8 @@
# Group Sequence Policy Optimization
-**Version Requirement**: ms-swift>=3.7
+**Version Requirement**: ms-swift>=3.8
-In [Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level.
+In [Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level.
Below are the three main strategies for computing importance sampling weights:
diff --git a/docs/source_en/Instruction/Use-tuners.md b/docs/source_en/Instruction/Use-tuners.md
index f960591893..d1b4f2cb1d 100644
--- a/docs/source_en/Instruction/Use-tuners.md
+++ b/docs/source_en/Instruction/Use-tuners.md
@@ -15,7 +15,7 @@ Tuners refer to additional structural components attached to a model, aimed at r
- Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
- Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
- Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
-- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) >
+- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) >
- Tuners provided by [PEFT](https://github.com/huggingface/peft), such as AdaLoRA, DoRA, Fourierft, etc.
## Interface List
diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md
index 8e0ef3085a..446916c1f7 100644
--- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md
+++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md
@@ -262,28 +262,6 @@ LoRA Training:
- lora_bias: Default is `'none'`. Available options: `'none'`, `'all'`. If you want all biases to be set as trainable, set this to `'all'`.
- use_rslora: Default is `False`. Whether to use `RS-LoRA`.
-**DPO Parameters**
-- ref_load: The loading path for the reference model. This must be provided when using DPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`.
-- ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`.
-- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1.
-- 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default.
- - **Note**: In "ms-swift<3.8", the default value was `1.`. Starting from "ms-swift>=3.8", the default has been changed to `None`.
-- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`.
-- label_smoothing: Default is 0.
-- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
-- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values.
-
-**KTO Parameters**:
-- ref_load: same meaning as in DPO.
-- ref_adapter_load: same meaning as in DPO.
-- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`.
-- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type.
-- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
-- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
-
-**RM Parameters**:
-- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01.
-
**Mcore-Bridge Parameters**
- 🔥load_safetensors: Defaults to False. Whether to load weights directly from safetensors.
@@ -291,7 +269,7 @@ LoRA Training:
- model: The model_id or model_path of safetensors weights. Defaults to None.
- model_type: Model type. For details, refer to [ms-swift command-line parameters documentation](../Instruction/Command-line-parameters.md).
- adapters: adapter_id or adapter_path of LoRA incremental weights in safetensors format. Default is `[]`.
-- ref_model: model_id or model_path of ref_model safetensors weights. Required when using DPO or KTO algorithms with full-parameter training. Default is None, set to `--model`.
+- ref_model: model_id or model_path of ref_model safetensors weights. Required when using DPO/GRPO/KTO algorithms with full-parameter training. Default is None, in which case it falls back to `--model`.
- ref_adapters: List of adapter_id or adapter_path of ref_adapters safetensors weights (currently only supports length of 1). Default is `[]`.
- use_hf: Controls whether to use ModelScope or HuggingFace for model download, dataset download, and model push. Default is False, using ModelScope.
- hub_token: Hub token. ModelScope hub token can be found [here](https://modelscope.cn/my/myaccesstoken). Default is None.
@@ -313,7 +291,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa
- Typically used together with `--freeze_vit false` and `--freeze_aligner false`.
- aligner_lr: Specifies the learning rate for the aligner module in multimodal models. Default is `None`, same as `learning_rate`.
- gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled.
-- 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM.
+- 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, GRPO, DPO, KTO and RM.
- Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next.
- Note: **Packing will reduce the number of dataset samples. Please adjust global_batch_size and learning rate accordingly**.
- packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length.
@@ -337,11 +315,83 @@ Megatron training parameters are inherited from Megatron parameters and basic pa
In addition to inheriting the training parameters, the following parameters are also supported:
-- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'kto', and 'rm' are available.
+- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'grpo', 'kto', and 'rm' are available.
- loss_scale: Overrides the `loss_scale` in [basic parameters](../Instruction/Command-line-parameters.md). Default is 'last_round'.
- calculate_per_token_loss: Overrides the Megatron parameter. Default is False.
+### DPO Parameters
+
+- ref_load: The loading path for the reference model. This must be provided when using DPO/GRPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`.
+- ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`.
+- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1.
+- 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default.
+ - **Note**: In "ms-swift<3.8", the default value was `1.`. Starting from "ms-swift>=3.8", the default has been changed to `None`.
+- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`.
+- label_smoothing: Default is 0.
+- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
+- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values.
+
+### KTO Parameters
+
+- ref_load: same meaning as in DPO.
+- ref_adapter_load: same meaning as in DPO.
+- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`.
+- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type.
+- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
+- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
+
+### RM Parameters
+
+- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01.
+
+### GRPO Parameters
+
+- ref_load: Same meaning as in DPO.
+- ref_adapter_load: Same meaning as in DPO.
+- beta: KL regularization coefficient, default is 0.04. When set to 0, the ref model is not loaded.
+- micro_batch_size: Batch size per device, default is 1.
+- global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallel size * gradient accumulation steps`. Default is 16.
+- steps_per_generation: Number of optimization steps per generation round, i.e., the ratio of sampling batch size to global_batch_size. Default is 1.
+- generation_batch_size: Sampling batch size, must be a multiple of global_batch_size. Default equals global_batch_size * steps_per_generation.
+- num_generations: Number of samples per prompt, the G value in the paper, default is 8.
+- reward_funcs: GRPO algorithm reward functions. Options include `accuracy`, `format`, `cosine`, `repetition`, and `soft_overlong`. See swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`.
+- reward_weights: Weights for each reward function. Must match the total number of reward functions and reward models. Default is None, meaning all rewards have equal weights of `1.0`.
+ - Tip: If GRPO training includes `--reward_model`, it is added at the end of the reward functions.
+- loss_type: Loss normalization type. Options are `['grpo', 'bnpo', 'dr_grpo']`. Default is `'grpo'`. See this [PR](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348) for details.
+- log_completions: Whether to log model-generated content during training. Default is False.
+- vllm_mode: vLLM integration mode. Options are `server` and `colocate`. Server mode uses the vLLM server launched by `swift rollout` for sampling, while colocate mode deploys vLLM within the program.
+- vllm_mode server parameters:
+ - vllm_server_host: vLLM server host address. Default is None.
+ - vllm_server_port: vLLM server port. Default is 8000.
+ - vllm_server_base_url: Base URL of the vLLM server (e.g., http://localhost:8000). Default is None. When set, host and port settings are ignored.
+ - vllm_server_timeout: Timeout for connecting to the vLLM server. Default is 240s.
+ - vllm_server_pass_dataset: Pass additional dataset information to the vLLM server for multi-round training.
+ - async_generate: Asynchronous rollout to improve training speed. Note: When enabled, sampling uses the model from the previous round update, and multi-round scenarios are not supported. Default is `false`.
+ - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: Environment variable for controlling the bucket size during weight synchronization. Applicable to full-parameter training in Server Mode. Unit is MB, default value is 512 MB.
+- vllm_mode colocate parameters (for more parameter support, refer to [vLLM parameters](#vllm-parameters)):
+ - vllm_gpu_memory_utilization: vLLM passthrough parameter. Default is 0.9.
+ - vllm_max_model_len: vLLM passthrough parameter. Default is None.
+ - vllm_enforce_eager: vLLM passthrough parameter. Default is False.
+ - vllm_limit_mm_per_prompt: vLLM passthrough parameter. Default is None.
+ - vllm_enable_prefix_caching: vLLM passthrough parameter. Default is True.
+ - vllm_tensor_parallel_size: Tensor parallel size. Default is `1`.
+ - vllm_enable_lora: Support loading LoRA adapters in the vLLM Engine. Default is False. Used to accelerate weight synchronization in LoRA training. See [documentation](../Instruction/GRPO/GetStarted/GRPO.md#weight-synchronization-acceleration) for details.
+ - sleep_level: Release vLLM GPU memory during training. Options are `[0, 1, 2]`. Default is 0, meaning no release.
+ - offload_optimizer: Whether to offload optimizer parameters during vLLM inference. Default is False.
+ - offload_model: Whether to offload the model during vLLM inference. Default is False.
+- num_iterations: Number of updates per data sample, the $\mu$ value in the [GRPO paper](https://arxiv.org/abs/2402.03300). Default is 1.
+- epsilon: Clip coefficient. Default is 0.2.
+- epsilon_high: Upper clip coefficient. Default is None. When set, together with epsilon, forms the clipping range `[epsilon, epsilon_high]`.
+- dynamic_sample: Filter out data with zero reward standard deviation within groups and sample additional new data. Default is False.
+- max_resample_times: Limit the number of resampling times under dynamic_sample setting. Default is 3.
+- overlong_filter: Skip overlong truncated samples, which do not participate in loss calculation. Default is False.
+- delta: Bilateral GRPO upper bound clipping value from the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). If set, it is recommended to be greater than 1 + epsilon. Default is None.
+- importance_sampling_level: Controls importance sampling ratio calculation. Options are `token` and `sequence`. In `token` mode, the original log probability ratio for each token is preserved. In `sequence` mode, the log probability ratios of all valid tokens in the sequence are averaged. The [GSPO paper](https://arxiv.org/abs/2507.18071) uses sequence-level calculation to stabilize training. Default is `token`.
+- scale_rewards: Specifies the reward scaling strategy. Options include `group` (scale by within-group standard deviation), `batch` (scale by batch-wide standard deviation), and `none` (no scaling). In ms-swift < 3.10, this parameter is boolean, where `true` corresponds to `group` and `false` corresponds to `none`. The default value is bound to `advantage_estimator`: `grpo` corresponds to `group`, `rloo` corresponds to `none`, and `reinforce_plus_plus` corresponds to `batch`.
+
+For the built-in reward function parameters, refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters).
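+
+As a minimal sketch of how these parameters combine, the following colocate-mode launch is abridged from `examples/megatron/grpo/dense_colocate.sh` in this repository; parameters not shown keep their defaults.
+
+```bash
+# Colocate-mode GRPO sketch (abridged from examples/megatron/grpo/dense_colocate.sh).
+# vLLM runs inside the training processes; sleep_level/offload free GPU memory while training.
+NPROC_PER_NODE=8 \
+megatron rlhf \
+    --rlhf_type grpo \
+    --model Qwen/Qwen2.5-VL-3B-Instruct \
+    --load_safetensors true \
+    --save_safetensors true \
+    --dataset AI-ModelScope/clevr_cogen_a_train#10000 \
+    --global_batch_size 128 \
+    --micro_batch_size 4 \
+    --steps_per_generation 4 \
+    --num_generations 8 \
+    --external_plugins examples/train/grpo/plugin/plugin.py \
+    --reward_funcs external_r1v_acc format \
+    --use_vllm true \
+    --vllm_mode colocate \
+    --vllm_gpu_memory_utilization 0.7 \
+    --sleep_level 2 \
+    --offload_model true \
+    --offload_optimizer true \
+    --beta 0.001 \
+    --train_type full \
+    --lr 1e-6 \
+    --max_completion_length 2048
+```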
+
## Export Parameters
This section introduces the parameters for `megatron export` (requires "ms-swift>=3.10"). To use the `swift export` command for exporting, please refer to the [ms-swift Command Line Parameters Documentation](../Instruction/Command-line-parameters.md#export-arguments). Compared to `swift export`, `megatron export` supports distributed and multi-node exporting. Megatron export parameters inherit from Megatron parameters and basic parameters.
diff --git a/docs/source_en/Megatron-SWIFT/GRPO.md b/docs/source_en/Megatron-SWIFT/GRPO.md
new file mode 100644
index 0000000000..3fa9dfb58d
--- /dev/null
+++ b/docs/source_en/Megatron-SWIFT/GRPO.md
@@ -0,0 +1,61 @@
+# Megatron GRPO
+
+**Version Requirement**: ms-swift >= 3.11
+
+If you are new to GRPO, please refer to the [GRPO documentation](../Instruction/GRPO/GetStarted/GRPO.md) first.
+
+Megatron GRPO currently supports the following features:
+
+- **Training Modes**: Full parameter training and LoRA fine-tuning
+- **Parallelism Strategies**: Context Parallelism (CP), Pipeline Parallelism (PP), Tensor Parallelism (TP), and Expert Parallelism (EP)
+- **Inference Acceleration**: vLLM colocate mode and server mode
+- **Model Support**: Compatible with the LLMs and MLLMs (multimodal large models) supported in Megatron-SWIFT
+- **Algorithm Support**: Covers most features of ms-swift GRPO
+
+The following parameters or features will be gradually supported in future versions:
+
+- **Entropy-related Configuration**: e.g., `top_entropy_quantile`, `log_entropy`
+- **Reward Model / Reward Model Plugin**
+- **Multi-turn Rollout Scheduling** (`multi_turn_scheduler`): Multi-turn conversation policy optimization
+- **Advantage Estimator** (`advantage_estimator`): Support for more complex policy gradient estimation methods
+- **KL Divergence in Reward** (`kl_in_reward`)
+- **Virtual Pipeline Parallelism** (VPP)
+- **Reference Model Synchronization** (`sync_ref_model`)
+- **Async Generate** (`async_generate`)
+- **num_iterations**
+- **SwanLab Logging Integration**
+
+⚠️ **Note**: The following parameters are not effective in Megatron GRPO:
+
+- **`use_vllm`**: Megatron GRPO does not support using PTEngine for Rollout inference.
+- **`move_model_batches`**: This parameter is specific to DeepSpeed ZeRO-3 optimization and is invalid in the Megatron architecture.
+
+Similar to ms-swift GRPO, all batch size-related parameters in Megatron GRPO are at the **completion-level**, meaning they represent the number of completions generated by the model, not the number of prompts.
+
+## Parameter Comparison
+
+The following table compares the batch-related parameters between ms-swift and Megatron-SWIFT:
+
+| ms-swift Parameter | Megatron-SWIFT Parameter | Description |
+|-------------------|--------------------------|-------------|
+| `per_device_train_batch_size` | `micro_batch_size` | Training batch size per GPU (completion-level) |
+| `gradient_accumulation_steps` | - | Gradient accumulation steps, already included in `global_batch_size` calculation in Megatron-SWIFT |
+| - | `global_batch_size` | Global batch size (completion-level)<br>**Megatron-SWIFT**: `micro_batch_size × dp_size × gradient_accumulation_steps`<br>**ms-swift**: `per_device_train_batch_size × world_size × gradient_accumulation_steps` |
+| `num_generations` | `num_generations` | Number of completions generated per prompt |
+| `steps_per_generation` | `steps_per_generation` | Ratio of Rollout batch size to training batch size<br>**Note**: In ms-swift, must be an integer multiple of `gradient_accumulation_steps` |
+| `generation_batch_size` | `generation_batch_size` | Batch size during Rollout phase (completion-level), must be an integer multiple of `global_batch_size` |
+
+The following formulas are used to calculate batch sizes in Megatron GRPO:
+
+- **Data Parallel Size**: `dp_size = world_size / (TP × PP × CP)`
+- **Global Batch Size**: `global_batch_size = micro_batch_size × dp_size × gradient_accumulation_steps`
+- **Generation Batch Size**: `generation_batch_size = global_batch_size × steps_per_generation`
+- **Rollout Prompt Count**: `num_rollout_prompts = generation_batch_size / num_generations`
+- **Training Prompt Count**: `num_train_prompts = global_batch_size / num_generations`
+- **Training Prompt Count per DP Group**: `num_prompts_per_dp_group = global_batch_size / num_generations / dp_size`
+
+**Note**: In Megatron GRPO, the training prompt count per DP group must satisfy that `num_prompts_per_dp_group` is an integer multiple of `micro_batch_size` to ensure proper batch allocation during training.
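+
+As a worked example, plugging the configuration of the `dense_colocate.sh` example script (8 GPUs, TP = PP = CP = 1, `micro_batch_size=4`, `global_batch_size=128`, `steps_per_generation=4`, `num_generations=8`) into the formulas above gives the following sketch:
+
+```bash
+# Batch-size bookkeeping for the dense_colocate.sh configuration.
+world_size=8; tp=1; pp=1; cp=1
+micro_batch_size=4; global_batch_size=128
+steps_per_generation=4; num_generations=8
+
+dp_size=$((world_size / (tp * pp * cp)))                                # 8
+grad_accum_steps=$((global_batch_size / (micro_batch_size * dp_size))) # 4
+generation_batch_size=$((global_batch_size * steps_per_generation))    # 512
+num_rollout_prompts=$((generation_batch_size / num_generations))       # 64
+num_train_prompts=$((global_batch_size / num_generations))             # 16
+echo "dp=$dp_size ga=$grad_accum_steps gen_bs=$generation_batch_size rollout_prompts=$num_rollout_prompts train_prompts=$num_train_prompts"
+```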
+
+For more parameters, please refer to the [Command-line Parameters documentation](./Command-line-parameters.md#grpo-parameters).
+
+For training scripts, please refer to [Megatron GRPO Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo).
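+
+For quick orientation, here is a condensed server-mode sketch, abridged from the `dense_server.sh` example: the rollout server and the trainer run on disjoint GPUs in separate terminals, and omitted flags keep their defaults.
+
+```bash
+# Terminal 1: vLLM rollout server on GPUs 6,7 (abridged from examples/megatron/grpo/dense_server.sh).
+MAX_PIXELS=602112 \
+CUDA_VISIBLE_DEVICES=6,7 \
+swift rollout \
+    --model Qwen/Qwen2.5-VL-3B-Instruct \
+    --vllm_data_parallel_size 2 \
+    --vllm_max_model_len 10240
+
+# Terminal 2: Megatron GRPO trainer on GPUs 0-5, pointing at the rollout server.
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
+NPROC_PER_NODE=6 \
+megatron rlhf \
+    --rlhf_type grpo \
+    --model Qwen/Qwen2.5-VL-3B-Instruct \
+    --dataset AI-ModelScope/clevr_cogen_a_train#10000 \
+    --external_plugins examples/train/grpo/plugin/plugin.py \
+    --reward_funcs external_r1v_acc format \
+    --use_vllm true \
+    --vllm_mode server \
+    --vllm_server_host 127.0.0.1 \
+    --vllm_server_port 8000 \
+    --global_batch_size 96 \
+    --micro_batch_size 4 \
+    --steps_per_generation 4 \
+    --num_generations 8 \
+    --train_type full \
+    --lr 1e-6
+```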
diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md
index 9f339cc547..d3d96dde1f 100644
--- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md
+++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md
@@ -1,6 +1,6 @@
# Multimodal Models
-ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO/KTO/RM for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md).
+ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/GRPO/DPO/KTO/RM for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md).
For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md).
diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md
index ed46f0471f..94123c8c4e 100644
--- a/docs/source_en/Megatron-SWIFT/Quick-start.md
+++ b/docs/source_en/Megatron-SWIFT/Quick-start.md
@@ -7,9 +7,10 @@ ms-swift incorporates Megatron's parallelization techniques to accelerate the tr
| ---------------------------------- | -------------- | ---- | ---- | ---------- |
| Pretraining | ✅ | ✅ | ✅ | ✅ |
| Instruction-supervised fine-tuning | ✅ | ✅ | ✅ | ✅ |
+| GRPO | ✅ | ✅ | ✅ | ✅ |
| DPO | ✅ | ✅ | ✅ | ✅ |
| KTO | ✅ | ✅ | ✅ | ✅ |
-| RM | ✅ | ✅ | ✅ | ✅ |
+| RM | ✅ | ✅ | ✅ | ✅ |
| Classification tasks | ✅ | ✅ | ✅ | ✅ |
## Environment Setup
diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst
index c5a5fc08c8..f70a8a05c9 100644
--- a/docs/source_en/index.rst
+++ b/docs/source_en/index.rst
@@ -42,6 +42,7 @@ Swift DOCUMENTATION
Megatron-SWIFT/LoRA-Training.md
Megatron-SWIFT/Multimodal-Model.md
Megatron-SWIFT/Mcore-Bridge.md
+ Megatron-SWIFT/GRPO.md
.. toctree::
diff --git a/examples/megatron/grpo/dense_colocate.sh b/examples/megatron/grpo/dense_colocate.sh
new file mode 100644
index 0000000000..4cbd7cafbb
--- /dev/null
+++ b/examples/megatron/grpo/dense_colocate.sh
@@ -0,0 +1,65 @@
+# DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size)
+# = 8 // (1 * 1 * 1) = 8
+
+# NOTE: global_batch_size and micro_batch_size are completion-level
+# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (128)
+# generation_batch_size = global_batch_size * steps_per_generation (128 * 4 = 512)
+# num_of_prompt_to_rollout = generation_batch_size / num_generations (512 / 8 = 64)
+# num_of_prompt_to_train = global_batch_size / num_generations (128 / 8 = 16)
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NPROC_PER_NODE=8 \
+MAX_PIXELS=602112 \
+MASTER_PORT=29600 \
+megatron rlhf \
+ --rlhf_type grpo \
+ --model Qwen/Qwen2.5-VL-3B-Instruct \
+ --load_safetensors true \
+ --save_safetensors true \
+ --context_parallel_size 1 \
+ --tensor_model_parallel_size 1 \
+ --pipeline_model_parallel_size 1 \
+ --dataset AI-ModelScope/clevr_cogen_a_train#10000 \
+ --max_epochs 1 \
+ --global_batch_size 128 \
+ --micro_batch_size 4 \
+ --steps_per_generation 4 \
+ --num_generations 8 \
+ --external_plugins examples/train/grpo/plugin/plugin.py \
+ --reward_funcs external_r1v_acc format \
+ --use_vllm true \
+ --vllm_mode colocate \
+ --vllm_gpu_memory_utilization 0.7 \
+ --vllm_max_model_len 10240 \
+ --max_length 8192 \
+ --max_completion_length 2048 \
+ --train_type full \
+ --lr 1e-6 \
+ --bf16 true \
+ --beta 0.001 \
+ --importance_sampling_level token \
+ --epsilon 0.2 \
+ --epsilon_high 0.2 \
+ --dynamic_sample false \
+ --overlong_filter true \
+ --loss_type grpo \
+ --sleep_level 2 \
+ --offload_model true \
+ --offload_optimizer true \
+ --log_interval 1 \
+ --recompute_granularity selective \
+ --finetune \
+ --num_workers 8 \
+ --dataset_num_proc 8 \
+ --no_save_optim \
+ --no_save_rng \
+ --attention_backend flash \
+ --temperature 1.0 \
+ --system examples/train/grpo/prompt.txt \
+ --padding_free true \
+ --log_completions true \
+ --wandb_project megatron_swift \
+ --wandb_exp_name megatron_grpo \
+ --train_iters 100 \
+ --eval_interval 1000 \
+ --save_interval 1000
diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh
new file mode 100644
index 0000000000..ee702800e2
--- /dev/null
+++ b/examples/megatron/grpo/dense_server.sh
@@ -0,0 +1,72 @@
+# MAX_PIXELS=602112 \
+# CUDA_VISIBLE_DEVICES=6,7 \
+# swift rollout \
+# --model Qwen/Qwen2.5-VL-3B-Instruct \
+# --vllm_data_parallel_size 2 \
+# --vllm_max_model_len 10240
+
+# DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size)
+# = 6 // (1 * 1 * 1) = 6
+
+# NOTE: global_batch_size and micro_batch_size are completion-level
+# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (96)
+# generation_batch_size = global_batch_size * steps_per_generation (96 * 4 = 384)
+# num_of_prompt_to_rollout = generation_batch_size / num_generations (384 / 8 = 48)
+# num_of_prompt_to_train = global_batch_size / num_generations (96 / 8 = 12)
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
+NPROC_PER_NODE=6 \
+MAX_PIXELS=602112 \
+MASTER_PORT=29600 \
+megatron rlhf \
+ --rlhf_type grpo \
+ --model Qwen/Qwen2.5-VL-3B-Instruct \
+ --load_safetensors true \
+ --save_safetensors true \
+ --context_parallel_size 1 \
+ --tensor_model_parallel_size 1 \
+ --pipeline_model_parallel_size 1 \
+ --dataset AI-ModelScope/clevr_cogen_a_train#10000 \
+ --max_epochs 1 \
+ --global_batch_size 96 \
+ --micro_batch_size 4 \
+ --steps_per_generation 4 \
+ --num_generations 8 \
+ --external_plugins examples/train/grpo/plugin/plugin.py \
+ --reward_funcs external_r1v_acc format \
+ --use_vllm true \
+ --vllm_mode server \
+ --vllm_server_host 127.0.0.1 \
+ --vllm_server_port 8000 \
+ --max_length 8192 \
+ --max_completion_length 2048 \
+ --train_type full \
+ --lr 1e-6 \
+ --bf16 true \
+ --beta 0.001 \
+ --importance_sampling_level token \
+ --epsilon 0.2 \
+ --epsilon_high 0.2 \
+ --dynamic_sample false \
+ --overlong_filter true \
+ --loss_type grpo \
+ --sleep_level 2 \
+ --offload_model true \
+ --offload_optimizer true \
+ --log_interval 1 \
+ --recompute_granularity selective \
+ --finetune \
+ --num_workers 8 \
+ --dataset_num_proc 8 \
+ --no_save_optim \
+ --no_save_rng \
+ --attention_backend flash \
+ --temperature 1.0 \
+ --system examples/train/grpo/prompt.txt \
+ --padding_free true \
+ --log_completions true \
+ --wandb_project megatron_swift \
+ --wandb_exp_name megatron_grpo \
+ --train_iters 100 \
+ --eval_interval 1000 \
+ --save_interval 1000
diff --git a/examples/megatron/grpo/moe_colocate_full.sh b/examples/megatron/grpo/moe_colocate_full.sh
new file mode 100644
index 0000000000..7b66688fd9
--- /dev/null
+++ b/examples/megatron/grpo/moe_colocate_full.sh
@@ -0,0 +1,55 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NPROC_PER_NODE=8 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+megatron rlhf \
+ --rlhf_type grpo \
+ --model Qwen/Qwen3-30B-A3B-Instruct-2507 \
+ --load_safetensors true \
+ --save_safetensors true \
+ --context_parallel_size 1 \
+ --tensor_model_parallel_size 4 \
+ --expert_model_parallel_size 4 \
+ --pipeline_model_parallel_size 2 \
+ --dataset open-r1/DAPO-Math-17k-Processed \
+ --max_epochs 1 \
+ --global_batch_size 8 \
+ --micro_batch_size 1 \
+ --steps_per_generation 1 \
+ --num_generations 8 \
+ --reward_funcs accuracy format \
+ --use_vllm true \
+ --vllm_mode colocate \
+ --vllm_gpu_memory_utilization 0.4 \
+ --vllm_tensor_parallel_size 8 \
+ --vllm_max_model_len 16384 \
+ --max_length 8192 \
+ --max_completion_length 8192 \
+ --train_type full \
+ --lr 1e-6 \
+ --bf16 true \
+ --beta 0.00 \
+ --importance_sampling_level sequence \
+ --epsilon 3e-4 \
+ --epsilon_high 4e-4 \
+ --dynamic_sample false \
+ --overlong_filter true \
+ --loss_type grpo \
+ --sleep_level 2 \
+ --offload_model true \
+ --offload_optimizer true \
+ --optimizer_cpu_offload true \
+ --use_precision_aware_optimizer \
+ --log_interval 1 \
+ --recompute_granularity selective \
+ --finetune \
+ --num_workers 8 \
+ --dataset_num_proc 8 \
+ --no_save_optim \
+ --no_save_rng \
+ --attention_backend flash \
+ --temperature 1.0 \
+ --padding_free true \
+ --sequence_parallel true \
+ --log_completions true \
+ --wandb_project megatron_swift \
+ --wandb_exp_name megatron_grpo
diff --git a/examples/megatron/grpo/moe_colocate_lora.sh b/examples/megatron/grpo/moe_colocate_lora.sh
new file mode 100644
index 0000000000..361a233e6c
--- /dev/null
+++ b/examples/megatron/grpo/moe_colocate_lora.sh
@@ -0,0 +1,53 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NPROC_PER_NODE=8 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+megatron rlhf \
+ --rlhf_type grpo \
+ --model Qwen/Qwen3-30B-A3B-Instruct-2507 \
+ --load_safetensors true \
+ --save_safetensors true \
+ --context_parallel_size 2 \
+ --tensor_model_parallel_size 2 \
+ --expert_model_parallel_size 4 \
+ --pipeline_model_parallel_size 2 \
+ --dataset open-r1/DAPO-Math-17k-Processed \
+ --max_epochs 1 \
+ --global_batch_size 64 \
+ --micro_batch_size 2 \
+ --steps_per_generation 2 \
+ --num_generations 8 \
+ --reward_funcs accuracy format \
+ --use_vllm true \
+ --vllm_mode colocate \
+ --vllm_gpu_memory_utilization 0.3 \
+ --vllm_tensor_parallel_size 4 \
+ --vllm_max_model_len 16384 \
+ --max_length 8192 \
+ --max_completion_length 8192 \
+ --train_type lora \
+ --lr 5e-5 \
+ --bf16 true \
+ --beta 0.00 \
+ --importance_sampling_level sequence \
+ --epsilon 3e-4 \
+ --epsilon_high 4e-4 \
+ --dynamic_sample false \
+ --overlong_filter true \
+ --loss_type grpo \
+ --sleep_level 2 \
+ --offload_model true \
+ --offload_optimizer true \
+ --log_interval 1 \
+ --recompute_granularity selective \
+ --finetune \
+ --num_workers 8 \
+ --dataset_num_proc 8 \
+ --no_save_optim \
+ --no_save_rng \
+ --attention_backend flash \
+ --temperature 1.0 \
+ --padding_free true \
+ --sequence_parallel true \
+ --log_completions true \
+ --wandb_project megatron_swift \
+ --wandb_exp_name megatron_grpo
diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py
index 2cba486e3f..588a4ee621 100644
--- a/swift/llm/dataset/dataset/llm.py
+++ b/swift/llm/dataset/dataset/llm.py
@@ -925,3 +925,10 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
],
dataset_name='self-cognition',
tags=['chat', 'self-cognition', '🔥']))
+
+register_dataset(
+ DatasetMeta(
+ ms_dataset_id='open-r1/DAPO-Math-17k-Processed',
+ hf_dataset_id='open-r1/DAPO-Math-17k-Processed',
+ subsets=['all'],
+ tags=['math', 'rlvr']))
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index fce20eb7d2..12afad29af 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1275,6 +1275,8 @@ def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None:
cp_size = self.sequence_parallel_size
if not self.use_megatron or cp_size == 1:
return
+ if self.mode == 'vllm': # skip for megatron grpo rollout
+ return
input_ids = encoded['input_ids']
padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids)
input_ids += [self.tokenizer.pad_token_id] * padding_len
diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py
index 556a631f36..4609df7948 100644
--- a/swift/megatron/argument/megatron_args.py
+++ b/swift/megatron/argument/megatron_args.py
@@ -17,7 +17,7 @@
@dataclass
class RLHFMegatronArgumentsMixin:
- rlhf_type: Literal['dpo', 'kto', 'rm'] = None
+ rlhf_type: Literal['dpo', 'kto', 'grpo', 'rm'] = None
ref_load: Optional[str] = None
ref_adapter_load: Optional[str] = None
@@ -36,6 +36,97 @@ class RLHFMegatronArgumentsMixin:
# rm
center_rewards_coefficient: Optional[float] = None
+ # grpo
+ generation_batch_size: Optional[int] = None
+ steps_per_generation: Optional[int] = None
+ num_generations: int = 8
+ max_completion_length: int = 512
+ # GSPO https://arxiv.org/abs/2507.18071
+ importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'
+
+ epsilon: float = 0.2
+ epsilon_high: Optional[float] = None
+ delta: Optional[float] = None
+ top_k: int = 50
+ top_p: float = 0.9
+ repetition_penalty: float = 1.
+ use_vllm: bool = True
+ vllm_mode: Literal['server', 'colocate'] = 'colocate'
+
+ vllm_enable_prefix_caching: bool = True
+ vllm_gpu_memory_utilization: float = 0.9
+ vllm_tensor_parallel_size: int = 1
+ vllm_max_model_len: Optional[int] = None
+ vllm_enforce_eager: bool = False
+ vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
+ vllm_disable_cascade_attn: bool = False
+ sleep_level: Literal[0, 1, 2] = 0
+ offload_optimizer: bool = False
+ offload_model: bool = False
+
+ vllm_server_base_url: Optional[List[str]] = None
+ vllm_server_host: Optional[List[str]] = None
+ vllm_server_port: List[int] = field(default_factory=lambda: [8000])
+ vllm_server_timeout: float = 240.0
+
+ reward_funcs: List[str] = field(default_factory=list)
+ reward_weights: List[float] = None
+ # see details in swift/plugin/orm.py
+ # cosine reward, https://arxiv.org/abs/2502.03373
+ cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length.
+ cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length.
+ cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length.
+ cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length.
+ cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length
+ # repetition penalty, https://arxiv.org/abs/2502.03373
+ repetition_n_grams: int = 3
+ repetition_max_penalty: float = -1.0
+ # soft_overlong, https://arxiv.org/abs/2503.14476
+ soft_max_length: Optional[int] = None
+ soft_cache_length: Optional[int] = None
+ # DAPO, https://arxiv.org/abs/2503.14476
+ dynamic_sample: bool = False
+ max_resample_times: int = 3
+ overlong_filter: bool = False
+
+ # Dr. GRPO, https://arxiv.org/abs/2503.20783
+ scale_rewards: Literal['none', 'group', 'batch'] = 'group'
+
+ wandb_log_unique_prompts: Optional[bool] = None
+ log_completions: bool = False
+
+ # ─────────────────────────── Not Supported Yet ───────────────────────────
+ # RLOO / REINFORCE++
+ advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus'] = 'grpo'
+ kl_in_reward: bool = False
+ # reward model
+ reward_model: Optional[List[str]] = None
+ reward_model_plugin: Optional[List[str]] = None
+ # sync ref model
+ sync_ref_model: bool = False
+ ref_model_sync_steps: int = 512
+ ref_model_mixup_alpha: float = 0.6
+
+ async_generate: bool = False
+
+ move_model_batches: Optional[int] = None
+
+ # multi turn
+ multi_turn_scheduler: Optional[str] = None
+ max_turns: Optional[int] = None
+ completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round'
+ vllm_server_pass_dataset: bool = False
+
+ # entropy
+ log_entropy: bool = False
+ # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
+ top_entropy_quantile: float = 1.0
+
+ num_iterations: int = 1
+
+ # dataset
+ dataset_shuffle: Optional[bool] = True
+
def _init_kto(self):
if self.calculate_KL is None:
# Not all losses require a KL calculation
@@ -46,11 +137,104 @@ def _init_kto(self):
def __post_init__(self):
if self.rlhf_type is None:
return
- default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid'}
+ default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid', 'grpo': 'grpo'}
if self.loss_type is None:
self.loss_type = default_loss_type.get(self.rlhf_type)
if self.rlhf_type == 'kto':
self._init_kto()
+ if self.rlhf_type == 'grpo':
+ self._init_grpo()
+
+ def _init_grpo(self):
+
+ def _check_not_supported():
+ if self.async_generate:
+ raise ValueError('async_generate is not supported for Megatron GRPO right now')
+ if self.sync_ref_model:
+ raise ValueError('sync_ref_model is not supported for Megatron GRPO right now')
+ if not self.dataset_shuffle:
+ raise ValueError('dataset_shuffle false is not supported for Megatron GRPO')
+ if self.multi_turn_scheduler:
+ raise ValueError('multi_turn_scheduler is not supported for Megatron GRPO right now')
+ if self.log_entropy:
+ raise ValueError('log_entropy is not supported for Megatron GRPO right now')
+ if self.top_entropy_quantile < 1:
+ raise ValueError('top_entropy_quantile < 1 is not supported for Megatron GRPO right now')
+ if self.num_iterations > 1:
+ raise ValueError('num_iterations > 1 is not supported for Megatron GRPO right now')
+ if self.kl_in_reward:
+ raise ValueError('kl_in_reward is not supported for Megatron GRPO right now')
+ if self.advantage_estimator != 'grpo':
+ raise ValueError('advantage_estimator must be grpo for Megatron GRPO right now')
+
+ def _check_batch_params():
+ # Set default values if both are None
+ if self.generation_batch_size is None and self.steps_per_generation is None:
+ self.steps_per_generation = 1
+ self.generation_batch_size = self.global_batch_size * self.steps_per_generation
+ # Both configured - error
+ elif self.generation_batch_size is not None and self.steps_per_generation is not None:
+ raise ValueError("'generation_batch_size' and 'steps_per_generation' cannot be both configured")
+ # Only generation_batch_size configured
+ elif self.generation_batch_size is not None:
+ if self.generation_batch_size % self.global_batch_size != 0:
+ raise ValueError(f'generation_batch_size ({self.generation_batch_size}) '
+ f'must be divisible by global_batch_size ({self.global_batch_size})')
+ self.steps_per_generation = self.generation_batch_size // self.global_batch_size
+ # Only steps_per_generation configured
+ else:
+ self.generation_batch_size = self.global_batch_size * self.steps_per_generation
+
+ world_size = torch.distributed.get_world_size()
+ dp_size = world_size // (
+ self.pipeline_model_parallel_size * self.tensor_model_parallel_size * self.context_parallel_size)
+ num_rollout_prompt = self.generation_batch_size // self.num_generations
+ if num_rollout_prompt % dp_size != 0:
+ raise ValueError(f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size '
+ f'({self.generation_batch_size}) // num_generations ({self.num_generations}) '
+ f'must be divisible by dp_size ({dp_size}). '
+ f'Please adjust generation_batch_size/steps_per_generation/num_generations.')
+
+ per_device_num_rollout_prompt = num_rollout_prompt // dp_size
+
+ if per_device_num_rollout_prompt % self.micro_batch_size != 0:
+ raise ValueError(f'Per-device rollout prompt count ({per_device_num_rollout_prompt}) = '
+ f'(generation_batch_size ({self.generation_batch_size}) // '
+ f'num_generations ({self.num_generations})) // dp_size ({dp_size}) '
+ f'must be divisible by micro_batch_size ({self.micro_batch_size}). '
+ f'Please adjust arguments to satisfy: '
+ f'(generation_batch_size // num_generations) // dp_size % '
+ f'micro_batch_size == 0')
+
+ self.per_device_generation_batch_size = self.generation_batch_size // world_size
+
+ _check_not_supported()
+ _check_batch_params()
+ # loss_type has already been defaulted in __post_init__; validate it for GRPO
+ assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \
+ f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}'
+ self.remove_unused_columns = False
+ logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
+ if self.truncation_strategy is None:
+ self.truncation_strategy = 'left'
+ assert self.truncation_strategy in ['left', 'delete'
+ ], ("GRPO requires truncation_strategy 'left' or 'delete', "
+ f"but got truncation_strategy='{self.truncation_strategy}'."
+ ) # noqa
+ if self.beta is None:
+ self.beta = 0.04 # https://arxiv.org/abs/2402.03300
+ if self.async_generate:
+ logger.info('Using async mode. This is an approximate version which '
+ 'uses the old weights to generate responses to accelerate training. '
+ 'It ignores the `CLIP` of advantages; if you find training '
+ 'unstable, consider setting --async_generate false.')
+ if 'soft_overlong' in self.reward_funcs:
+ assert self.soft_cache_length is not None, \
+ 'The soft_cache_length must be set when using soft overlong rewards.'
+ if self.soft_max_length is None:
+ self.soft_max_length = self.max_completion_length
+ logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}')
+ assert self.use_vllm, 'use_vllm must be True for Megatron GRPO'
@dataclass
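For quick reference, the resolution rule enforced by `_check_batch_params` above can be summarized as follows. This is an illustrative sketch, not the actual implementation; the helper name `resolve` is invented here.

```python
# Sketch of how generation_batch_size and steps_per_generation are resolved (illustrative only).
def resolve(global_batch_size, generation_batch_size=None, steps_per_generation=None):
    if generation_batch_size is None and steps_per_generation is None:
        steps_per_generation = 1  # default: train on each rollout exactly once
    elif generation_batch_size is not None and steps_per_generation is not None:
        raise ValueError("'generation_batch_size' and 'steps_per_generation' cannot be both configured")
    if generation_batch_size is not None:
        if generation_batch_size % global_batch_size != 0:
            raise ValueError('generation_batch_size must be divisible by global_batch_size')
        steps_per_generation = generation_batch_size // global_batch_size
    else:
        generation_batch_size = global_batch_size * steps_per_generation
    return generation_batch_size, steps_per_generation

assert resolve(128, steps_per_generation=4) == (512, 4)   # dense_colocate.sh
assert resolve(96, steps_per_generation=4) == (384, 4)    # dense_server.sh
```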
diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py
index 127f928ef1..5106175e2f 100644
--- a/swift/megatron/argument/rlhf_args.py
+++ b/swift/megatron/argument/rlhf_args.py
@@ -7,7 +7,7 @@
@dataclass
class MegatronRLHFArguments(MegatronTrainArguments):
- rlhf_type: Literal['dpo', 'kto', 'rm'] = 'dpo'
+ rlhf_type: Literal['dpo', 'kto', 'grpo', 'rm'] = 'dpo'
loss_scale: str = 'last_round'
calculate_per_token_loss: bool = False
diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py
index 3a25bda9fe..d27f5aabb2 100644
--- a/swift/megatron/train/rlhf.py
+++ b/swift/megatron/train/rlhf.py
@@ -2,9 +2,10 @@
from typing import List, Optional, Union
from swift.llm.train.kto import prepare_kto_dataset
-from swift.utils import get_logger
+from swift.trainers.rlhf_trainer.utils import identity_data_collator
+from swift.utils import get_current_device, get_logger, is_last_rank
from ..argument import MegatronRLHFArguments
-from ..trainers import MegatronDPOTrainer, MegatronKTOTrainer, MegatronRewardTrainer
+from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer, MegatronKTOTrainer, MegatronRewardTrainer
from .sft import MegatronSft
logger = get_logger()
@@ -16,18 +17,29 @@ class MegatronRLHF(MegatronSft):
def prepare_trainer(self):
args = self.args
- trainer_mapping = {'dpo': MegatronDPOTrainer, 'kto': MegatronKTOTrainer, 'rm': MegatronRewardTrainer}
+ trainer_mapping = {
+ 'dpo': MegatronDPOTrainer,
+ 'grpo': MegatronGRPOTrainer,
+ 'kto': MegatronKTOTrainer,
+ 'rm': MegatronRewardTrainer
+ }
trainer_cls = trainer_mapping.get(args.rlhf_type)
if trainer_cls is None:
raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.')
- return trainer_cls(args, self.template)
+ kwargs = {}
+ if args.rlhf_type == 'grpo':
+ kwargs['vllm_client'] = self._prepare_vllm_client()
+ return trainer_cls(args, self.template, **kwargs)
def _prepare_template(self) -> None:
super()._prepare_template()
- if self.args.rlhf_type == 'kto':
- self.template.set_mode('kto')
- else:
- self.template.set_mode('rlhf')
+ mode_mapping = {'grpo': 'train', 'kto': 'kto'}
+ self.template.set_mode(mode_mapping.get(self.args.rlhf_type, 'rlhf'))
+
+ def _get_data_collator(self):
+ if self.args.rlhf_type == 'grpo':
+ return identity_data_collator
+ return super()._get_data_collator()
def _get_dataset(self):
args = self.args
@@ -36,6 +48,23 @@ def _get_dataset(self):
train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
return train_dataset, val_dataset
+ def _prepare_vllm_client(self):
+ if self.args.rlhf_type != 'grpo' or (self.args.vllm_mode != 'server'):
+ return
+ from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
+ vllm_client = None
+ if is_last_rank():
+ logger.info('Start connecting to vLLM server')
+ vllm_client = VLLMClient(
+ base_urls=self.args.vllm_server_base_url,
+ hosts=self.args.vllm_server_host,
+ server_ports=self.args.vllm_server_port,
+ connection_timeout=self.args.vllm_server_timeout)
+ vllm_client.close_communicator()
+ vllm_client.init_communicator(device=get_current_device())
+ logger.info('Connected to vLLM server')
+ return vllm_client
+
def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None):
return MegatronRLHF(args).main()
diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py
index a70e1d7f10..80cf16fe22 100644
--- a/swift/megatron/trainers/__init__.py
+++ b/swift/megatron/trainers/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .dpo_trainer import MegatronDPOTrainer
+from .grpo_trainer import MegatronGRPOTrainer
from .kto_trainer import MegatronKTOTrainer
from .reward_trainer import MegatronRewardTrainer
from .trainer import MegatronTrainer
diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py
index 164fe0ee0a..2b6d938cc4 100644
--- a/swift/megatron/trainers/base.py
+++ b/swift/megatron/trainers/base.py
@@ -31,7 +31,7 @@
from packaging import version
from tqdm.auto import tqdm
-from swift.llm import dynamic_gradient_checkpointing
+from swift.llm import Template, dynamic_gradient_checkpointing
from swift.plugin import MeanMetric
from swift.trainers import SwiftMixin
from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger
@@ -50,11 +50,12 @@
class BaseMegatronTrainer(ABC):
- def __init__(self, args, template):
+ def __init__(self, args, template: Template):
self.args = args
self.template = template
self.stimer = StragglerDetector()
self.unwrapped_models = []
+ self.wrapped_models = []
self.peft_models = []
self._bridge = None
logging_path = os.path.join(args.save, 'logging.jsonl')
@@ -86,9 +87,11 @@ def initialize_megatron(*_args, **kwargs):
args = get_args()
data_parallel_size = mpu.get_data_parallel_world_size()
step_batch_size = args.micro_batch_size * data_parallel_size
+ num_generations = args.num_generations if hasattr(args, 'num_generations') else 1
if args.train_iters is None and args.max_epochs is not None:
if hasattr(train_dataset, '__len__'):
dataset_sample = len(train_dataset) // step_batch_size * step_batch_size
+ dataset_sample = dataset_sample * num_generations
args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size
else:
raise ValueError(
@@ -98,6 +101,7 @@ def initialize_megatron(*_args, **kwargs):
args.eval_iters = 0
elif hasattr(val_dataset, '__len__'):
dataset_sample = len(val_dataset) // step_batch_size * step_batch_size
+ dataset_sample = dataset_sample * num_generations
args.eval_iters = max(dataset_sample // args.global_batch_size, 1)
else:
raise ValueError(
@@ -419,6 +423,7 @@ def new_model_provider_func(*_args, **kwargs):
with self._patch_load_state_dict(self._load_base_checkpoint), self._patch_get_param_groups():
model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer(
new_model_provider_func, model_type, *_args, **kwargs)
+ self.wrapped_models = model
if args.initialize_embedding:
for m in self.unwrapped_models:
self._initialize_embedding(m)
@@ -937,6 +942,7 @@ def _patch_megatron(self):
# support max_epochs
self._origin_train_step = training.train_step
training.train_step = self.train_step
+ self._origin_cyclic_iter = training.cyclic_iter
training.cyclic_iter = self.new_cyclic_iter
# patch training_log
self._origin_training_log = training.training_log
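With `dataset_sample` scaled by `num_generations` above, `max_epochs` effectively counts completions rather than prompts when deriving `train_iters`. A rough worked example with the `dense_colocate.sh` settings, assuming `--train_iters` were left unset (the example script actually sets it to 100, so this path is not taken there):

```python
# Rough derivation of train_iters from max_epochs for a GRPO run (illustrative numbers).
num_prompts = 10_000                    # AI-ModelScope/clevr_cogen_a_train#10000
micro_batch_size, dp_size = 4, 8
num_generations, max_epochs, global_batch_size = 8, 1, 128

step_batch_size = micro_batch_size * dp_size                          # 32
dataset_sample = num_prompts // step_batch_size * step_batch_size     # 9984 prompts
dataset_sample *= num_generations                                     # 79872 completions
train_iters = dataset_sample * max_epochs // global_batch_size        # 624
```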
diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py
new file mode 100644
index 0000000000..d3253d4b39
--- /dev/null
+++ b/swift/megatron/trainers/grpo_trainer.py
@@ -0,0 +1,1405 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import base64
+import gc
+import inspect
+import os
+import uuid
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
+from copy import copy, deepcopy
+from functools import partial
+from typing import Any, Dict, List, Tuple, Union
+
+import json
+import pandas as pd
+import torch
+import torch.nn as nn
+from accelerate.utils import broadcast_object_list
+from dacite import from_dict
+from megatron.core import mpu
+from megatron.core.rerun_state_machine import RerunDataIterator
+from megatron.training import get_args, get_wandb_writer, training
+from trl.trainer.grpo_trainer import nanstd
+from vllm.distributed import parallel_state as vllm_ps
+
+from swift.llm import RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device
+from swift.llm.infer.protocol import RolloutOutput
+from swift.llm.template.template_inputs import TemplateInputs
+from swift.plugin import MultiTurnScheduler, multi_turns, orms
+from swift.trainers.rlhf_trainer.grpo_trainer import DataType
+from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, aggressive_empty_cache,
+ replace_assistant_response_with_ids, set_expandable_segments)
+from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available,
+ remove_response)
+from ..argument import MegatronArguments, MegatronRLHFArguments
+from ..utils import forward_step_helper
+from .rlhf_mixin import MegatronRLHFTrainer
+from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu,
+ load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer,
+ profiling_context, profiling_decorator)
+
+if is_wandb_available():
+ import wandb
+
+logger = get_logger()
+
+
+class MegatronGRPOTrainer(MegatronRLHFTrainer):
+
+ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
+ self.vllm_client = kwargs.pop('vllm_client')
+ super().__init__(args, template)
+ self.args = args
+ self.hf_model_dir = args.model_info.model_dir
+ self.processing_class = self.template.processor
+ self._prepare_metrics()
+ self._prepare_template_data_collator()
+ self._init_grpo_params()
+ self._prepare_rewards()
+ self._prepare_scheduler() # TODO
+ self._prepare_rollout_engine()
+
+ def train(self, train_dataset, val_dataset, data_collator):
+ # Store dataset provider for lazy resample iterator initialization
+ if self.dynamic_sample:
+ self._train_valid_test_dataset_provider = get_swift_datasets_provider(train_dataset, val_dataset)
+ self._train_valid_test_dataset_provider.is_distributed = True
+ super().train(train_dataset, val_dataset, data_collator)
+
+ def _prepare_template_data_collator(self):
+ template = self.template
+ args = self.args
+ data_collator = template.data_collator
+ padding_to = None
+ if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
+ padding_to = args.tensor_model_parallel_size
+ if args.context_parallel_size > 1:
+ padding_to = (padding_to or 1) * args.context_parallel_size
+ if args.fp8_format:
+ padding_to = max((padding_to or 1) * 8, 16)
+ logger.info(f'padding_to: {padding_to}')
+ data_collator = partial(data_collator, padding_to=padding_to)
+ template.data_collator = data_collator
+
+ def _init_grpo_params(self):
+ args: MegatronArguments = self.args
+ # distributed params
+ self.world_size = torch.distributed.get_world_size()
+ self.process_index = torch.distributed.get_rank()
+ self.is_main_process = is_last_rank()
+ self.device = get_current_device()
+ # algorithm params
+ self.num_generations = args.num_generations # G in the GRPO paper
+ self.beta = args.beta
+ self.temperature = args.temperature
+ self.loss_type = args.loss_type
+ self.max_completion_length = args.max_completion_length
+ self.epsilon_low = args.epsilon
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
+ self.top_entropy_quantile = args.top_entropy_quantile
+ self.importance_sampling_level = args.importance_sampling_level
+ self.enable_offload = False
+
+ # DAPO, https://arxiv.org/abs/2503.14476
+ self.dynamic_sample = args.dynamic_sample
+ self.max_resample_times = args.max_resample_times
+ self.overlong_filter = args.overlong_filter
+
+ # Dr. GRPO / RLOO / REINFORCE++
+ self.scale_rewards = args.scale_rewards
+ self.advantage_estimator = args.advantage_estimator # TODO
+ self.kl_in_reward = args.kl_in_reward # TODO
+
+ # Entropy mask settings, TODO
+ self.log_entropy = args.log_entropy
+ self.compute_entropy = self.log_entropy or self.top_entropy_quantile < 1.0
+
+ # batch size (completion-level)
+ self.generation_batch_size = args.generation_batch_size
+ self.steps_per_generation = args.steps_per_generation
+ self.global_batch_size = args.global_batch_size
+ self.micro_batch_size = args.micro_batch_size
+ self.per_device_generation_batch_size = args.per_device_generation_batch_size
+
+ # sampling params
+ self.request_config = RequestConfig(
+ n=1,
+ max_tokens=args.max_completion_length,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ repetition_penalty=args.repetition_penalty,
+ stop=args.stop_words,
+ return_details=True)
+
+ self._step = 0
+ self._last_loaded_step = -1
+ self._rollout_group = None # Will be lazily initialized
+
+ def _prepare_rollout_engine(self):
+ args = self.args
+ self.vllm_mode = args.vllm_mode
+ self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
+ self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
+ self.use_vllm = args.use_vllm
+ self.async_generate = args.async_generate # TODO
+ self.vllm_use_async_engine = False
+ self.enable_offload = False
+ self.use_gym_env = False
+ self.enable_server_multi_turn = False # TODO
+ # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs
+ assert self.use_vllm
+ if not is_vllm_available():
+ raise ImportError('vLLM is not available and `use_vllm` is set to True. '
+ 'Please install vLLM with `pip install vllm -U` to use it.')
+ if self.vllm_mode == 'server':
+ pass
+ elif self.vllm_mode == 'colocate':
+ if self.world_size % self.vllm_tensor_parallel_size != 0:
+ raise ValueError(f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size '
+ f'({self.world_size}) evenly.')
+
+ self.enable_offload = self.args.offload_model or self.args.offload_optimizer
+ context = self.offload_context if self.enable_offload else nullcontext
+
+ with context():
+ set_expandable_segments(False)
+ self.engine = self.prepare_vllm()
+ if self.args.sleep_level > 0:
+ self.engine.engine.sleep(self.args.sleep_level)
+ set_expandable_segments(True)
+ else:
+ raise ValueError(f'Invalid vllm_mode: {self.vllm_mode}')
+
+ def prepare_vllm(self):
+ from swift.llm.infer.infer_engine import GRPOVllmEngine
+ args = self.args
+ max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size
+ vllm_template = copy(self.template)
+ vllm_template.padding_free = False
+ engine = GRPOVllmEngine(
+ self.hf_model_dir,
+ args.torch_dtype,
+ model_type=args.model_type,
+ use_async_engine=False,
+ tensor_parallel_size=self.vllm_tensor_parallel_size,
+ gpu_memory_utilization=self.vllm_gpu_memory_utilization,
+ enable_prefix_caching=self.args.vllm_enable_prefix_caching,
+ max_num_seqs=max_num_seqs,
+ enforce_eager=self.args.vllm_enforce_eager,
+ limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
+ enable_sleep_mode=self.args.sleep_level > 0,
+ max_model_len=self.args.vllm_max_model_len,
+ seed=self.process_index // self.vllm_tensor_parallel_size,
+ disable_cascade_attn=self.args.vllm_disable_cascade_attn,
+ load_format='dummy',
+ template=vllm_template,
+ distributed_executor_backend='external_launcher',
+ )
+ if self.vllm_tensor_parallel_size > 1:
+ self.vllm_tp_group = vllm_ps.get_tp_group().device_group
+ self._buffered_inputs = None
+ return engine
+
+ @profiling_decorator
+ def _move_model_to_vllm(self):
+ # Handle LoRA: merge adapters before exporting weights
+ is_lora_training = self.args.train_type == 'lora'
+
+ try:
+ if is_lora_training:
+ self.merge_lora_adapters()
+
+ # Export and load weights incrementally to avoid memory spikes
+ self._export_and_load_weights()
+
+ finally:
+ # Unmerge adapters to restore training state
+ if is_lora_training:
+ self.unmerge_lora_adapters()
+
+ # Reset prefix cache
+ if self.vllm_mode == 'server' and self.is_main_process:
+ self.vllm_client.reset_prefix_cache()
+ elif self.vllm_mode == 'colocate':
+ self.engine.engine.reset_prefix_cache()
+
+ @property
+ def bridge(self):
+ if self._bridge is None:
+ self._bridge = self.args.megatron_model_meta.bridge_cls(disable_tqmd=True)
+ return self._bridge
+
+ def _export_and_load_weights(self):
+ """
+ Export weights from Megatron models and load to vLLM incrementally.
+
+ For colocate mode: llm_model.load_weights accepts an iterator, so pass it directly.
+ For server mode: Process weights in buckets to avoid memory spikes.
+ """
+ # Export weights returns an iterator
+ with profiling_context(self, 'export_weights'):
+ weight_iterator = self.bridge.export_weights(self.unwrapped_models)
+
+ if self.vllm_mode == 'colocate':
+ # Colocate mode: load_weights supports iterator, pass directly
+ llm_model = self.engine.inner_model
+ llm_model.load_weights(weight_iterator)
+ elif self.vllm_mode == 'server' and self.is_main_process:
+ # Server mode: process in buckets and sync with flattened tensors
+ self._load_weights_to_server_in_buckets(weight_iterator)
+
+ def _load_weights_to_server_in_buckets(self, weight_iterator):
+ """
+ Load weights to vLLM server in buckets using FlattenedTensorBucket.
+
+ Args:
+ weight_iterator: Iterator of (name, tensor) tuples from export_weights
+ """
+ # Get bucket size from environment or use default
+ bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512))
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+
+ current_bucket = []
+ current_size = 0
+
+ for name, param in weight_iterator:
+ param_size = param.numel() * param.element_size()
+ current_bucket.append((name, param))
+ current_size += param_size
+
+ # If adding this param would exceed bucket size, process current bucket first
+ if current_size > bucket_size_bytes and current_bucket:
+ self._sync_bucket_to_server(current_bucket)
+ current_bucket = []
+ current_size = 0
+
+ # Process remaining parameters in the last bucket
+ if current_bucket:
+ self._sync_bucket_to_server(current_bucket)
+
+ def _sync_bucket_to_server(self, bucket_params: List[Tuple[str, torch.Tensor]]):
+ """
+ Synchronize a bucket of parameters to vLLM server using flattened tensors.
+
+ Args:
+ bucket_params: List of (name, tensor) tuples to sync
+ """
+ if not bucket_params:
+ return
+
+ # Create FlattenedTensorBucket for efficient transfer
+ bucket = FlattenedTensorBucket(named_tensors=bucket_params)
+ metadatas = bucket.get_metadata()
+ flattened_tensor = bucket.get_flattened_tensor()
+
+ # Directly call vllm_client to update weights
+ self.vllm_client.update_flattened_params(metadatas, flattened_tensor)
+
+ # Clean up to free memory immediately
+ del bucket, metadatas, flattened_tensor
+
+ def _prepare_rewards(self):
+ # TODO: reward model
+ args = self.args
+ reward_funcs = args.reward_funcs
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+
+ # initialize reward functions
+ if reward_funcs:
+ for i, reward_func in enumerate(reward_funcs):
+ if reward_func in orms:
+ reward_func_class = orms[reward_func]
+ reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
+ reward_func_kwargs = {
+ key: getattr(args, key)
+ for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
+ }
+ if 'tokenizer' in reward_func_args:
+ reward_func_kwargs['tokenizer'] = self.processing_class
+ reward_funcs[i] = reward_func_class(**reward_func_kwargs)
+ elif not callable(reward_func):
+ raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin')
+
+ # get reward name for logging
+ self.reward_funcs = reward_funcs
+ self.reward_func_names = []
+ for reward_func in reward_funcs:
+ if inspect.isfunction(reward_func):
+ reward_func_name = reward_func.__name__
+ else:
+ reward_func_name = reward_func.__class__.__name__
+ self.reward_func_names.append(reward_func_name)
+
+ # set reward weights
+ if args.reward_weights is not None:
+ if len(args.reward_weights) != len(reward_funcs):
+ raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward '
+ f'functions ({len(reward_funcs)})')
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32).to(self.device)
+ else:
+ self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(self.device)
+
+ # TODO: reward models
+ self.reward_model_plugins = [None] * len(self.reward_funcs)
+
+ assert self.reward_funcs, 'reward_funcs is not set'
+
+ def _prepare_scheduler(self):
+ """Prepare multi-turn scheduler"""
+ args = self.args
+
+ self.multi_turn_scheduler = None
+ if not hasattr(args, 'multi_turn_scheduler'):
+ return
+
+ if args.multi_turn_scheduler:
+ if isinstance(args.multi_turn_scheduler, str):
+ assert args.multi_turn_scheduler in multi_turns
+ multi_turn_scheduler = multi_turns[args.multi_turn_scheduler](max_turns=args.max_turns)
+ self.multi_turn_scheduler: MultiTurnScheduler = multi_turn_scheduler
+ else:
+ assert isinstance(args.multi_turn_scheduler, MultiTurnScheduler)
+ self.multi_turn_scheduler: MultiTurnScheduler = args.multi_turn_scheduler
+
+ def _get_rollout_group(self):
+ """
+ Get or create the rollout process group (TP×PP×CP).
+
+ The rollout group is used for:
+ 1. Data slicing: distributing rollout data across all model parallel ranks (including CP)
+ 2. Gather operations: collecting results from all model parallel ranks (including CP)
+
+ Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct
+ data distribution during rollout phase.
+
+ Key insight: ranks with the same DP index but different TP/PP/CP indices should be
+ in the same rollout group. These ranks will:
+ - During rollout: each process different data slices
+ - During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split)
+ - During gather: collect all data from TP×PP×CP ranks for training
+ """
+ if self._rollout_group is not None:
+ return self._rollout_group
+
+ cp_size = mpu.get_context_parallel_world_size()
+ if cp_size == 1:
+ # No CP, use the standard MODEL_PARALLEL_GROUP
+ self._rollout_group = mpu.get_model_parallel_group()
+ return self._rollout_group
+
+ # Get parallel dimensions
+ tp_size = mpu.get_tensor_model_parallel_world_size()
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
+ dp_size = mpu.get_data_parallel_world_size()
+ global_rank = torch.distributed.get_rank()
+
+ # Calculate rollout group size
+ rollout_group_size = tp_size * pp_size * cp_size
+
+ # Simple and reliable method: assume ranks are organized in contiguous blocks per DP group
+ # This is typically true for the default order (tp-cp-ep-dp-pp)
+ # Each DP group has rollout_group_size consecutive ranks
+ ranks_per_dp_group = rollout_group_size
+ my_dp_block_index = global_rank // ranks_per_dp_group
+
+ # Calculate the rank range for my rollout group
+ group_start = my_dp_block_index * ranks_per_dp_group
+
+ # Create all rollout groups (must be done on all ranks)
+ if not hasattr(self, '_rollout_groups_created'):
+ for dp_idx in range(dp_size):
+ group_start = dp_idx * ranks_per_dp_group
+ group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size)))
+ group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP')
+ if global_rank in group_ranks:
+ self._rollout_group = group
+ self._rollout_groups_created = True
+
+ return self._rollout_group
+
+ def _init_resample_data_iterator(self):
+ """
+ Initialize an independent data iterator for dynamic resampling (lazy initialization).
+
+ This method is called lazily during the first dynamic resampling, ensuring that
+ pretrain() has already called initialize_megatron() to properly set up all args.
+ Uses a different seed (args.seed + 1) to avoid overlapping with training samples.
+
+ Note: pretrain() will automatically reset the random seed back to args.seed
+ after this method completes, so we don't need manual state restoration.
+
+ Args:
+ train_valid_test_dataset_provider: Dataset provider function
+
+ Returns:
+ train_data_iterator: Independent data iterator with different random seed
+ """
+ from megatron.training.training import build_train_valid_test_data_iterators
+ from megatron.training.initialize import _set_random_seed
+ from megatron.training import training
+ training.cyclic_iter = self._origin_cyclic_iter
+ args = get_args()
+
+ train_valid_test_dataset_provider = self._train_valid_test_dataset_provider
+ # Use different seed for resample iterator (offset by 1 to avoid overlap)
+ resample_seed = getattr(args, 'seed', 42) + 1
+ try:
+ # Set new seed for resample iterator creation
+ _set_random_seed(
+ resample_seed,
+ args.data_parallel_random_init,
+ args.te_rng_tracker,
+ args.inference_rng_tracker,
+ use_cudagraphable_rng=args.enable_cuda_graph,
+ )
+
+ # Build data iterators with new seed
+ # TODO: VPP (Virtual Pipeline Parallelism)
+ resample_data_iterator, _, _ = (build_train_valid_test_data_iterators(train_valid_test_dataset_provider))
+ finally:
+ # Restore original random states to avoid affecting training
+ _set_random_seed(
+ args.seed,
+ args.data_parallel_random_init,
+ args.te_rng_tracker,
+ args.inference_rng_tracker,
+ use_cudagraphable_rng=args.enable_cuda_graph,
+ )
+ return resample_data_iterator
+
+ def _replace_data_iterator(self, data_iterator, model):
+ if self._step % self.steps_per_generation == 0:
+ num_iters_per_step = self.get_num_iters_per_step()
+ rollout_batch = []
+ for _ in range(num_iters_per_step):
+ rollout_batch.extend(next(data_iterator))
+ micro_batch_data = self._generate_and_score_completions(rollout_batch)
+ num_mini_batch = self.global_batch_size // (self.micro_batch_size * mpu.get_data_parallel_world_size())
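+ # Regroup the scored micro-batches into steps_per_generation chunks of num_mini_batch each; one chunk per optimizer step.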
+ mini_batch_data = [
+ micro_batch_data[i:i + num_mini_batch] for i in range(0, len(micro_batch_data), num_mini_batch)
+ ]
+ assert len(mini_batch_data) == self.steps_per_generation
+ self._buffered_inputs = mini_batch_data
+ self._step += 1
+ inputs = self._buffered_inputs[self._step % self.steps_per_generation]
+ return RerunDataIterator(iter(inputs))
+
+ def _generate_and_score_completions(self, batch):
+ # Get or create the rollout group (TP×PP×CP)
+ rollout_group = self._get_rollout_group()
+
+ rollout_batch = self.get_local_rollout_batch(batch)
+
+ rollout_batch = self._generate_completions(rollout_batch)
+
+ rewards_per_func = self._score_completions(rollout_batch)
+
+ # Dynamic sampling for std=0 groups (DAPO)
+ if self.dynamic_sample:
+ rollout_batch, rewards_per_func = self._dynamic_sampling(rollout_batch, rewards_per_func)
+
+ advantages = self._compute_advantages(rollout_batch, rewards_per_func)
+
+ def _get_encoded_batch(rollout_batch, advantages):
+ template = self.template
+ with self._template_context(template):
+ encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch]
+ encoded_batch = to_device(template.data_collator(encoded_batch), self.device)
+ labels = encoded_batch['labels']
+ assert self.template.padding_free
+ position_ids = encoded_batch.get('text_position_ids')
+ if position_ids is None:
+ position_ids = encoded_batch.get('position_ids')
+ squeezed_position_ids = position_ids.squeeze()
+ assert squeezed_position_ids is not None
+ # Remove trailing padding zeros from position_ids to avoid interference
+ # Find the last non-zero position
+ last_nonzero_idx = (squeezed_position_ids != 0).nonzero(as_tuple=True)[0]
+ if len(last_nonzero_idx) > 0:
+ # Keep only up to the last non-zero position + 1 to include the last valid position
+ squeezed_position_ids = squeezed_position_ids[:last_nonzero_idx[-1] + 1]
+
+ # Calculate lengths based on sequence boundaries (position_ids == 0)
+ lengths = torch.diff(
+ torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0],
+ torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)]))
+ advantages = torch.repeat_interleave(advantages, lengths)
+ truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch],
+ dtype=torch.bool,
+ device=self.device)
+ truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0)
+ padding_length = labels.shape[1] - truncated_mask.shape[1]
+ if padding_length > 0:
+ padding = torch.zeros((1, padding_length), device=truncated_mask.device, dtype=truncated_mask.dtype)
+ truncated_mask = torch.cat([truncated_mask, padding], dim=1)
+ # Pad advantages to match the original position_ids length
+ original_length = position_ids.shape[1]
+ if advantages.shape[0] < original_length:
+ padding_length = original_length - advantages.shape[0]
+ padding = torch.zeros(padding_length, device=advantages.device, dtype=advantages.dtype)
+ advantages = torch.cat([advantages, padding])
+
+ encoded_batch.update({
+ 'completion_mask': labels != -100,
+ 'truncated_mask': truncated_mask,
+ 'advantages': advantages,
+ 'num_samples': len(rollout_batch),
+ })
+
+ return encoded_batch
+
+ # Step2: ref/old logps
+ total_batch = gather_object(rollout_batch, group=rollout_group)
+ total_advantages = gather(advantages, group=rollout_group)
+ mini_batch_data = []
+ for idx in range(0, len(total_batch), self.micro_batch_size):
+ micro_batch_data = total_batch[idx:idx + self.micro_batch_size]
+ micro_batch_data = self._maybe_replace_response_token(micro_batch_data)
+ micro_batch_advantages = total_advantages[idx:idx + self.micro_batch_size]
+ micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages)
+ with profiling_context(self, 'compute_ref_old_logps'):
+ micro_batch_data = self._maybe_compute_logps(micro_batch_data)
+ mini_batch_data.append(micro_batch_data)
+
+ return mini_batch_data
+
+ @profiling_decorator
+ def _generate_completions(self, batch):
+ """
+ Generate completions for a batch of rollout data using vLLM engine.
+
+ This method processes rollout data for the current process, generates completions
+ using the vLLM engine, and merges the results back into the original batch.
+
+ Args:
+ batch: Rollout data assigned to the current process.
+
+ Returns:
+ batch: The input batch with rollout completion results merged in.
+ """
+ # TODO: server mode
+ # add prompt ids and system prompts
+ batch = self._preprocess_inputs(batch)
+ # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode)
+ if self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping:
+ wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters
+ # Load weights only (faster and reduces memory peak)
+ kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {}
+ self.engine.engine.wake_up(**kwargs)
+
+ # Step 2: Load model weights
+ if self._step != self._last_loaded_step:
+ self._move_model_to_vllm()
+ self._last_loaded_step = self._step
+
+ context = self.offload_context if self.enable_offload else nullcontext
+ with context():
+ if (self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping
+ and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters):
+ aggressive_empty_cache()
+ set_expandable_segments(False)
+ self.engine.engine.wake_up(tags=['kv_cache'])
+
+ # Step 3: Rollout
+ outputs: List[RolloutOutput] = self._rollout(batch)
+
+ # Step 4: Sleep to release memory
+ if self.vllm_mode == 'colocate' and self.args.sleep_level > 0:
+ self.engine.engine.reset_prefix_cache()
+ self.engine.engine.sleep(level=self.args.sleep_level)
+ aggressive_empty_cache()
+ set_expandable_segments(True)
+ batch = self.postprocess_rollout_data(batch, outputs)
+
+ return batch
+
+ def _rollout(self, batch) -> List[RolloutOutput]:
+ batch = self._set_inputs_system(batch)
+ request_config = self._get_request_config()
+ # TODO: server mode
+ if self.vllm_mode == 'server':
+ rollout_outputs = self._server_rollout(batch, request_config)
+ elif self.vllm_mode == 'colocate':
+ rollout_outputs = self._colocate_rollout(batch, request_config)
+ # log prompt and completions
+ messages = gather_object([data['messages'] for data in batch])
+ completions = gather_object([data.response.choices[0].message.content for data in rollout_outputs])
+ self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(messages))
+ self._logs['completion'].extend(completions)
+
+ return rollout_outputs
+
+ def postprocess_rollout_data(self, batch, outputs):
+ """
+ Post-process the raw vLLM generation outputs and merge them back into the
+ original input batch.
+
+ Args:
+ batch (List[Dict[str, Any]]):
+ Original rollout samples.
+ outputs (List[RolloutOutput]):
+ Outputs returned by vLLM for this vLLM TP group.
+
+ Returns:
+ List[Dict[str, Any]]:
+ Updated samples with rollout results merged in.
+ """
+
+ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput):
+ response = output.response
+ choice = response.choices[0]
+
+ # Step 1: Update or append assistant message
+ if output.messages:
+ input_data['messages'] = output.messages # Override full message history
+ else:
+ # not provided, append
+ messages = input_data['messages']
+ remove_response(messages)
+ messages.append({'role': 'assistant', 'content': choice.message.content})
+ # Step 2: Add token IDs and loss mask
+ if output.response_token_ids:
+ input_data['response_token_ids'] = output.response_token_ids
+ if output.response_loss_mask:
+ input_data['response_loss_mask'] = output.response_loss_mask
+ else:
+ # for single turn, skip tokenizer response
+ input_data['response_token_ids'] = output.response.choices[0].token_ids
+
+ # Step 3: Attach rollout extra info
+ if output.rollout_infos:
+ input_data['rollout_infos'] = output.rollout_infos
+
+ # Step 4: Store finish reason (used for truncation filters etc.)
+ input_data['finish_reason'] = choice.finish_reason
+ input_data['is_truncated'] = choice.finish_reason == 'length'
+
+ return input_data
+
+ assert len(batch) == len(outputs)
+ return [merge_output_input_data(input_data, output) for input_data, output in zip(batch, outputs)]
+
+ def _get_request_config(self) -> RequestConfig:
+ request_config = copy(self.request_config)
+ if self.args.vllm_mode == 'colocate' and self.vllm_tensor_parallel_size > 1:
+ # Set request_config.seed
+ # 1. Ensure that the seed for vLLM Engines within each TP (Tensor Parallelism) group is the same;
+ # otherwise, the program may hang.
+ # 2. Ensure that the seed for vLLM Engines across different TP groups is different;
+ # otherwise, identical completions will be generated.
+ batch_size = self.per_device_generation_batch_size
+ batch_size *= self.vllm_tensor_parallel_size
+ # Since the TP (Tensor Parallelism) group gathers the inputs,
+ # multiply the batch size by the TP parallel size.
+ request_config.seed = batch_size * (self.process_index // self.vllm_tensor_parallel_size)
+
+ return request_config
+
+ def _server_rollout(self,
+ inputs: DataType,
+ request_config: RequestConfig,
+ is_global_inputs: bool = False) -> List[RolloutOutput]:
+ # TODO: async generate
+ infer_requests = self.inputs2requests(inputs)
+
+ if is_global_inputs:
+ per_device_size = len(infer_requests) // self.world_size
+ all_requests = infer_requests
+ all_requests_lengths = [per_device_size] + [0] * (self.world_size - 1)
+ else:
+ all_requests = gather_object(infer_requests)
+ all_requests_lengths = gather_object([len(infer_requests)])
+
+ if not any(requests for requests in all_requests):
+ return []
+
+ if self.is_main_process:
+ all_outputs: List[RolloutOutput] = self.vllm_client.infer(
+ infer_requests=all_requests, request_config=request_config)
+ assert len(all_outputs) == len(all_requests) # TODO: dynamic num of samples
+ else:
+ all_outputs = [None] * len(all_requests)
+
+ if not is_global_inputs:
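+ # The "main" process here is the last rank (is_last_rank), so the broadcast source is world_size - 1.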
+ all_outputs = broadcast_object_list(all_outputs, from_process=self.world_size - 1)
+ start_idx = sum(all_requests_lengths[:self.process_index])
+ end_idx = start_idx + all_requests_lengths[self.process_index]
+ outputs = all_outputs[start_idx:end_idx]
+ else:
+ outputs = all_outputs if self.is_main_process else []
+ return outputs
+
+ def _colocate_rollout(self, batch, request_config: RequestConfig):
+ if self.vllm_tensor_parallel_size > 1:
+ local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group)
+ local_input_length = len(batch)
+ all_input_lengths = [None] * self.vllm_tensor_parallel_size
+ torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.vllm_tp_group)
+
+ start_idx = sum(all_input_lengths[:local_rank_in_group])
+ end_idx = start_idx + all_input_lengths[local_rank_in_group]
+
+ gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)]
+ torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group)
+ batch = [p for sublist in gathered_batch for p in sublist]
+
+ outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config)
+
+ if self.vllm_tensor_parallel_size > 1:
+ outputs = outputs[start_idx:end_idx]
+
+ return outputs
+
+ @profiling_decorator
+ def _score_completions(self, inputs: DataType) -> torch.Tensor:
+ """Score completions using all reward functions.
+
+ Args:
+ inputs: List of input examples, each containing a 'messages' list with conversation history
+
+ Returns:
+ rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with local reward values
+ """
+ # Compute rewards using reward functions
+ local_rewards_per_func = self._compute_rewards_per_func(inputs)
+
+ return local_rewards_per_func
+
+ def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor:
+ """Compute rewards using all reward functions"""
+ device = self.device
+ rewards_per_func = torch.zeros((len(batch), len(self.reward_funcs)), device=device)
+ completions = [inp['messages'][-1]['content'] for inp in batch]
+ reward_kwargs = {} # TODO: training step info
+ for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate(
+ zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)):
+ with profiling_context(self, reward_func_name):
+ # reward model
+ if isinstance(reward_func, nn.Module):
+ output_reward_func = reward_model_plugin(inputs=batch, **reward_kwargs)
+ # reward function
+ else:
+ # Convert the input rows into batched columns and pass them to the reward function as kwargs
+ reward_kwargs.update(RowPreprocessor.rows_to_batched(batch))
+ output_reward_func = reward_func(completions, **reward_kwargs)
+ output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+ # If all reward functions return None for a given row, issue a detailed warning
+ if torch.isnan(rewards_per_func).all(dim=1).any():
+ nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
+ row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
+ row_reward_kwargs['completion'] = completions[nan_row_idx]
+ logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. '
+ 'Please ensure that at least one reward function returns a valid reward.')
+
+ return rewards_per_func
+
+ def _compute_advantages(self, batch: DataType, rewards_per_func: torch.Tensor) -> torch.Tensor:
+ """Compute advantages for RL training."""
+
+ def normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> torch.Tensor:
+ """Normalize advantages if configured; otherwise, return as-is."""
+ if self.scale_rewards != 'none':
+ return advantages / (rewards_std + 1e-4)
+ return advantages
+
+ mode = 'train' if self.unwrapped_models[0].training else 'eval'
+ assert len(batch) == rewards_per_func.shape[0]
+ total_rewards_per_func = gather(rewards_per_func)
+ rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1)
+ grouped_rewards = rewards.view(-1, self.num_generations)
+
+ # Compute group statistics
+ group_rewards_mean = grouped_rewards.mean(dim=1)
+
+ # Broadcast stats back to the original shape
+ group_rewards_mean = group_rewards_mean.repeat_interleave(self.num_generations)
+
+ # Compute advantages relative to group mean
+ advantages = rewards - group_rewards_mean
+
+ # Normalize advantages based on scale_rewards setting
+ if self.scale_rewards == 'batch':
+ # Global batch-level normalization
+ rewards_std = rewards.std().expand_as(rewards)
+ elif self.scale_rewards == 'group':
+ # Group-level normalization (default)
+ rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations)
+ else: # 'none'
+ rewards_std = None
+
+ if rewards_std is not None:
+ advantages = normalize_advantages(advantages, rewards_std)
+
+ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor):
+ """Log reward statistics for monitoring. Only log once per unique request_id."""
+ # rewards: [prompt_batch_size, self.num_generations]
+ # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs]
+ group_rewards = rewards.view(-1, self.num_generations)
+ rewards_mean = group_rewards.mean(-1).mean().item()
+ # Compute std based on scale_rewards setting for logging
+ if self.scale_rewards in ['group', 'none']:
+ rewards_std = group_rewards.std(-1).mean().item()
+ elif self.scale_rewards == 'batch':
+ rewards_std = rewards.std().item()
+ is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1)))
+
+ self._metrics[mode]['reward'].append(rewards_mean)
+ self._metrics[mode]['reward_std'].append(rewards_std)
+ self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item())
+
+ # Log per-reward-function statistics using deduplicated rewards_per_func
+ for i, name in enumerate(self.reward_func_names):
+ col = rewards_per_func_for_metrics[:, i]
+ self._metrics[mode][f'rewards/{name}/mean'].append(torch.nanmean(col).item())
+ self._metrics[mode][f'rewards/{name}/std'].append(nanstd(col).item())
+
+ log_rewards_metrics(rewards=grouped_rewards, rewards_per_func_for_metrics=total_rewards_per_func)
+ self._logs['advantages'].extend(advantages.tolist())
+ for i, name in enumerate(self.reward_func_names):
+ self._logs['rewards'][name].extend(total_rewards_per_func[:, i].tolist())
+
+ slice_start = self.process_index * len(batch)
+ slice_end = slice_start + len(batch)
+ advantages = advantages[slice_start:slice_end]
+
+ return advantages
+
+ def _dynamic_sampling(self, rollout_batch: DataType,
+ rewards_per_func: torch.Tensor) -> Tuple[DataType, torch.Tensor]:
+ """
+ Perform dynamic sampling to replace samples with zero-reward-variance groups.
+
+ This method implements DAPO (https://arxiv.org/abs/2503.14476) by replacing
+ samples from groups with zero reward variance (std=0) through resampling.
+
+ Args:
+ rollout_batch: local rollout data samples
+ rewards_per_func: reward per function for local data samples
+
+ Returns:
+ tuple: (rollout_batch, rewards_per_func) with zero-variance groups replaced by resampled data
+ """
+ resample_count = 0
+ valid_samples = []
+ valid_rewards_per_func = []
+ origin_data = (rollout_batch, rewards_per_func)
+
+ while resample_count < self.max_resample_times:
+ # Gather all samples and rewards across rollout group first
+ global_rollout_batch = gather_object(rollout_batch)
+ global_rewards_per_func = gather(rewards_per_func)
+
+ # Compute reward std for the entire global batch
+ # We need to compute std on the gathered data to get a global mask
+ global_rewards = (global_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1)
+ grouped_rewards = global_rewards.view(-1, self.num_generations)
+ group_rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations)
+ global_valid_mask = (group_rewards_std > 0)
+
+ # Filter valid samples based on std > 0
+ valid_samples.extend([sample for sample, mask in zip(global_rollout_batch, global_valid_mask) if mask])
+ valid_rewards_per_func.append(global_rewards_per_func[global_valid_mask])
+
+ if len(valid_samples) >= self.generation_batch_size:
+ break
+
+ # Lazy initialization of resample_data_iterator
+ # Only initialize when needed, after pretrain() has set up args
+ if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None:
+ self.resample_data_iterator = self._init_resample_data_iterator()
+ num_iters_per_step = self.get_num_iters_per_step()
+ next_rollout_prompt_batch = []
+ for _ in range(num_iters_per_step):
+ next_rollout_prompt_batch.extend(next(self.resample_data_iterator))
+
+ # Repeat num_generations times and get local slice
+ rollout_batch = self.get_local_rollout_batch(next_rollout_prompt_batch)
+
+ # Generate and score new completions
+ rollout_batch = self._generate_completions(rollout_batch)
+ rewards_per_func = self._score_completions(rollout_batch)
+ resample_count += 1
+
+ if len(valid_samples) >= self.generation_batch_size:
+ # Get local slice of valid samples
+ rank = self.process_index
+ per_device_batch_size = self.per_device_generation_batch_size
+ data_slice = slice(rank * per_device_batch_size, (rank + 1) * per_device_batch_size)
+ rollout_batch = valid_samples[:self.generation_batch_size][data_slice]
+ rewards_per_func = torch.cat(valid_rewards_per_func)[:self.generation_batch_size][data_slice]
+ else:
+ logger.warning(f'There are still std=0 groups present after {self.max_resample_times} retries.')
+ rollout_batch, rewards_per_func = origin_data
+
+ return rollout_batch, rewards_per_func
+
+ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+ # TODO: entropy
+ inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']}
+ if self.beta != 0.0:
+ with torch.no_grad(), self.null_ref_context() as ref_models:
+ assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
+ ref_model = ref_models[0]
+ batch['ref_per_token_logps'] = self.model_forward(
+ ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+
+ if not self.on_policy:
+ batch['old_per_token_logps'] = self.model_forward(
+ self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+ return batch
+
+ @contextmanager
+ def _disable_maxlength_template_context(self, template: Template):
+ # The max_length for prompt and completion has already been restricted, so there is no need for max_length here.
+ max_length = template.max_length
+ template.max_length = None
+ try:
+ yield
+ finally:
+ template.max_length = max_length
+
+ def _maybe_replace_response_token(self, batch):
+        # If response_token_ids are present, replace the assistant response text with them
+        # to avoid re-tokenizing the completion.
+
+ for data in batch:
+ if 'response_token_ids' in data and data['response_token_ids']:
+ loss_mask = None
+ if 'response_loss_mask' in data and data['response_loss_mask']:
+ loss_mask = data['response_loss_mask']
+ # token in token out
+ data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'],
+ loss_mask)
+ return batch
+
+ @property
+ def on_policy(self):
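+        # steps_per_generation == 1 means completions are optimized in the same step
+        # they were generated; larger values reuse rollouts across steps, in which case
+        # old_per_token_logps are needed (see _maybe_compute_logps).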
+ return self.steps_per_generation == 1
+
+ @contextmanager
+ def patch_megatron_data_collator(self, data_collator):
+ """
+        Context manager that temporarily patches Megatron's data-loader factory so that
+        every data loader it builds uses the GRPO data collator. The original factory is
+        restored on exit.
+ """
+ origin_build_pretraining_data_loader = training.build_pretraining_data_loader
+
+ def build_pretraining_data_loader(*_args, **kwargs):
+ args = get_args()
+ org_micro_batch_size = args.micro_batch_size
+ # args.micro_batch_size = org_micro_batch_size // self.num_generations
+ res = origin_build_pretraining_data_loader(*_args, **kwargs)
+ args.micro_batch_size = org_micro_batch_size
+ if res is not None and args.dataloader_type != 'external':
+ res.collate_fn = data_collator
+ return res
+
+ training.build_pretraining_data_loader = build_pretraining_data_loader
+ try:
+ yield
+ finally:
+ training.build_pretraining_data_loader = origin_build_pretraining_data_loader
+
+ @profiling_decorator
+ def forward_step(self, data_iterator, model):
+ # train_batch_size
+ # return: output_tensor, loss_func
+ data = self.get_batch(data_iterator)
+ data.pop('loss_scale', None)
+ inputs = {
+ k: v
+ for k, v in data.items() if k not in
+ ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask']
+ }
+
+ with self.stimer:
+ output_tensor = model(**inputs)
+ return output_tensor, partial(self.loss_func, data=data)
+
+ @profiling_decorator
+ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
+ advantages = data['advantages']
+ labels = data['labels']
+ completion_mask = data['completion_mask']
+ packed_seq_params = data['packed_seq_params']
+ truncated_mask = data['truncated_mask']
+ micro_batch_size = self.micro_batch_size
+ # Use full sequence lengths directly (get_logps returns full sequences in CP mode)
+ lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
+ + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
+ lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1]
+
+ # get_logps with per_token=True now returns full sequences (all_gather in CP mode)
+ per_token_logps = self.get_logps(
+ output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
+
+ if self.args.overlong_filter and truncated_mask.any():
+ completion_mask = completion_mask & (~truncated_mask)
+ if not completion_mask.any():
+ logger.warning('All completions are truncated in this batch. Loss and grad_norm will be 0. '
+ 'Consider increasing max_completion_length')
+
+ if self.beta != 0.0:
+ ref_per_token_logps = data.get('ref_per_token_logps')
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1)
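+            # Non-negative k3 KL estimator, exp(x) - x - 1 with x = ref_logp - logp
+            # (cf. Schulman's "Approximating KL Divergence" note).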
+
+ old_per_token_logps = (
+ per_token_logps.detach() if data.get('old_per_token_logps') is None else data['old_per_token_logps'])
+ log_ratio = per_token_logps - old_per_token_logps
+
+ if self.importance_sampling_level == 'token':
+ log_importance_weights = log_ratio
+ elif self.importance_sampling_level in ['sequence', 'sequence_token']:
+ log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist())
+ mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist())
+            # Masked mean of the log ratio for each sequence, stacked into a single tensor
+ seq_weights = torch.stack([(lr * m).sum() / m.sum().clamp(min=1.0)
+ for lr, m in zip(log_ratio_list, mask_list)])
+ seq_level_log_weights = seq_weights.to(log_ratio.dtype).unsqueeze(-1)
+ if self.importance_sampling_level == 'sequence':
+ log_importance_weights = seq_level_log_weights
+ else:
+ seq_level_log_weight = seq_level_log_weights.detach()
+ # Vectorized: use repeat_interleave with tensor directly
+ seq_level_log_weight = torch.repeat_interleave(
+ seq_level_log_weight.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0)
+ log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+ else:
+ raise ValueError(
+ f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
+ ",'sequence' and 'sequence_token'.")
+
+ coef_1 = torch.exp(log_importance_weights)
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+ if self.args.delta is not None:
+ coef_1 = torch.clamp(coef_1, max=self.args.delta)
+
+ if self.template.padding_free:
+ # In padding_free + sequence mode, coef_1 is [num_samples, 1]
+ # We need to expand to [1, total_tokens] for token-level loss computation
+ if self.importance_sampling_level == 'sequence':
+ # Vectorized: expand sequence-level weights to token-level without gradient
+ coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0)
+ coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0)
+
+ advantages = advantages[-coef_1.shape[1]:]
+ per_token_loss1 = coef_1 * advantages.unsqueeze(0)
+ per_token_loss2 = coef_2 * advantages.unsqueeze(0)
+ else:
+            raise NotImplementedError('GRPO loss currently requires padding_free (packing) mode.')
+ # per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+ # per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+ if self.beta != 0.0:
+ per_token_loss = per_token_loss + self.beta * per_token_kl
+
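+        # Loss aggregation per loss_type:
+        #   'grpo'    - per-sequence token mean, then mean over sequences
+        #   'bnpo'    - single token-level mean over the whole micro-batch
+        #   'dr_grpo' - token sum divided by micro_batch_size * max_completion_length,
+        #               which removes the length bias of per-sequence averaging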
+ if self.loss_type == 'grpo':
+ loss_list = torch.split(per_token_loss.squeeze(0), lengths_with_padding.tolist())
+ mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist())
+
+ sample_loss = torch.stack([(loss * mask).sum() / mask.sum().clamp(min=1.0)
+ for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size])
+ ])
+ loss = sample_loss.mean()
+ elif self.loss_type == 'bnpo':
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+ elif self.loss_type == 'dr_grpo':
+ loss = (per_token_loss * completion_mask).sum() / (micro_batch_size * self.max_completion_length)
+ else:
+ raise ValueError(f'Unknown loss type: {self.loss_type}')
+
+ avg_metric = {
+ 'loss': loss.clone().detach(),
+ }
+ total_lengths = gather(lengths, group=mpu.get_data_parallel_group(with_context_parallel=True))
+ custom_metrics = {
+ 'completions/mean_length': total_lengths.float().mean(),
+ 'completions/max_length': total_lengths.float().max(),
+ 'completions/min_length': total_lengths.float().min(),
+ }
+
+ if self.beta != 0.0:
+ # Unified processing (no CP-specific logic needed)
+ kl_value = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+ avg_metric['kl'] = kl_value.clone().detach()
+
+ mode = 'train' if self.unwrapped_models[0].training else 'eval'
+ if self._metrics[mode]:
+ addition_metrics = {
+ key: torch.tensor(sum(val) / len(val), device=loss.device)
+ for key, val in self._metrics[mode].items()
+ }
+ avg_metric.update(addition_metrics)
+
+ avg_metric = self._all_reduce_metric(avg_metric)
+
+ reporting_metric = {**avg_metric, **custom_metrics}
+
+ # log_completions
+ if self.log_completions and self.is_main_process and self._step % self.steps_per_generation == 0:
+ table = {
+ 'gen_step': [self._step] * len(self._logs['prompt']),
+ 'prompt': list(self._logs['prompt']),
+ 'completion': list(self._logs['completion']),
+ **{k: list(v)
+ for k, v in self._logs['rewards'].items()},
+ 'advantages': list(self._logs['advantages']),
+ }
+ self.jsonl_writer.append(table)
+ wandb_writer = get_wandb_writer()
+ if wandb_writer:
+ df = pd.DataFrame(table)
+ if self.wandb_log_unique_prompts:
+ df = df.drop_duplicates(subset=['prompt'])
+ # if not self.init_custom_metric:
+ # wandb_writer.define_metric('completions', step_metric='gen_step')
+ # self.init_custom_metric = True
+ wandb_writer.log({'completions': wandb.Table(dataframe=df)})
+
+ return loss, reporting_metric
+
+ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
+ # used to calculate model forward (logps) in GRPO
+ with self.stimer(bdata=True):
+ data = self.get_batch(data_iterator)
+ data.pop('loss_scale', None)
+ labels = data.get('labels')
+ context = torch.no_grad() if no_grad else nullcontext()
+ with context:
+ output_tensor = forward_step_helper(model, data)
+ packed_seq_params = data['packed_seq_params']
+ data['logps'] = None if labels is None else self.get_logps(
+ output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token)
+ return data
+
+ @contextmanager
+ def offload_context(self):
+ if self.args.offload_model:
+ offload_megatron_model_to_cpu(self.wrapped_models)
+ if hasattr(self, 'ref_models') and self.ref_models:
+ offload_megatron_model_to_cpu(self.ref_models)
+ if getattr(self, 'optimizer', None) and self.args.offload_optimizer:
+ offload_megatron_optimizer(self.optimizer)
+
+ try:
+ yield
+ finally:
+ # reload (load back) model when exiting context
+ if self.args.offload_model:
+ load_megatron_model_to_gpu(self.wrapped_models)
+ if hasattr(self, 'ref_models') and self.ref_models:
+ load_megatron_model_to_gpu(self.ref_models)
+ if getattr(self, 'optimizer', None) and self.args.offload_optimizer:
+ load_megatron_optimizer(self.optimizer)
+
+ def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]:
+ """Convert raw input data into RolloutInferRequest objects"""
+
+ def _process_image_data(image_data: Union[dict, str]) -> str:
+ if isinstance(image_data, dict):
+ if image_data.get('bytes'):
+ return base64.b64encode(image_data['bytes']).decode('utf-8')
+ if image_data.get('path'):
+ return image_data['path']
+ return image_data
+
+ if not inputs:
+ return []
+ args = self.args
+
+ REQUEST_METADATA_FIELDS = ['messages', 'images', 'audios', 'videos', 'tools', 'objects', 'uuid']
+ requests_dicts = []
+
+ for data in inputs:
+ request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data and data[key] is not None}
+ if 'uuid' not in request_data:
+ request_data['uuid'] = data['request_id']
+ if hasattr(args, 'vllm_server_pass_dataset') and args.vllm_server_pass_dataset:
+ extra_fields = {
+ k: v
+ for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and data[k] is not None
+ }
+ if extra_fields:
+ request_data['data_dict'] = extra_fields
+ elif self.multi_turn_scheduler:
+ base_data_dict = {}
+ if 'data_dict' in data:
+ if isinstance(data['data_dict'], dict):
+ base_data_dict = data['data_dict']
+ else:
+ raise ValueError('data_dict exists but is not a dictionary')
+ extra_data = {
+ k: v
+ for k, v in data.items()
+ if k not in REQUEST_METADATA_FIELDS and k != 'data_dict' and data[k] is not None
+ }
+ final_data_dict = {**extra_data, **base_data_dict}
+ request_data['data_dict'] = final_data_dict if final_data_dict else {}
+
+ requests_dicts.append(request_data)
+
+ for request in requests_dicts:
+ if 'images' in request and request['images']:
+ request['images'] = ([_process_image_data(img) for img in request['images']] if isinstance(
+ request['images'], list) else _process_image_data(request['images']))
+
+ return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts]
+
+ def _preprocess_inputs(self, inputs: DataType) -> DataType:
+ """Preprocess inputs before inference"""
+ processed_inputs = self._add_prompt_id_to_inputs(inputs)
+ for input_item in processed_inputs:
+ remove_response(input_item['messages'])
+ return processed_inputs
+
+ def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType:
+ """Add unique prompt_id and request_id to each input"""
+ if not inputs:
+ return inputs
+
+ all_messages = gather_object([inp['messages'] for inp in inputs])
+ messages_to_prompt_id = {}
+ prompt_id_counter = 0
+
+ for messages in all_messages:
+ key = json.dumps(messages)
+ if key not in messages_to_prompt_id:
+ messages_to_prompt_id[key] = f'prompt_{prompt_id_counter}'
+ prompt_id_counter += 1
+
+ for input_item in inputs:
+ messages = input_item.get('messages')
+ input_item['prompt_id'] = messages_to_prompt_id[json.dumps(messages)]
+ input_item['request_id'] = f'chatcmpl-{str(uuid.uuid4().hex)}'
+
+ return inputs
+
+ def get_num_iters_per_step(self):
+ if hasattr(self, '_num_iters_per_step'):
+ return self._num_iters_per_step
+ # each rollout DP group will generate generation_batch_size / dp_size completions
+ dp_size = mpu.get_data_parallel_world_size()
+ completions_to_rollout = self.generation_batch_size // dp_size
+        # each prompt is repeated num_generations times later, so divide the completion
+        # count by num_generations to get the number of prompts to roll out
+ prompts_to_rollout = completions_to_rollout // self.num_generations
+ # every iter will generate micro_batch_size prompts
+ num_iters_per_step = prompts_to_rollout // self.micro_batch_size
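+        # e.g. generation_batch_size=32, dp_size=2, num_generations=8, micro_batch_size=2
+        # -> 16 completions per DP rank -> 2 prompts -> 1 iteration per step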
+ assert num_iters_per_step > 0, (
+ f'num_iters_per_step={num_iters_per_step} <= 0. '
+            f'This means no prompts will be generated. '
+ f'generation_batch_size={self.generation_batch_size}, '
+ f'data_parallel_world_size={mpu.get_data_parallel_world_size()}, '
+ f'num_generations={self.num_generations}, '
+ f'micro_batch_size={self.micro_batch_size}. '
+ 'Please adjust these parameters so that '
+ 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.')
+ self._num_iters_per_step = num_iters_per_step
+ return num_iters_per_step
+
+ def get_local_rollout_batch(self, batch):
+ # repeat num_generations times
+ rollout_group = self._get_rollout_group()
+ global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)]
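+        # e.g. prompts [A, B] with num_generations=3 -> [A, A, A, B, B, B]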
+ # get local rollout data
+ rollout_rank = torch.distributed.get_rank(group=rollout_group)
+ rollout_group_size = torch.distributed.get_world_size(group=rollout_group)
+
+ per_device_batch_size = self.per_device_generation_batch_size
+ assert rollout_group_size * per_device_batch_size == len(global_rollout_batch)
+ data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size)
+ rollout_batch = global_rollout_batch[data_slice]
+ return rollout_batch
+
+ @contextmanager
+ def _template_context(self, template: Template):
+ # The max_length for prompt and completion has already been restricted, so there is no need for max_length here.
+ max_length = template.max_length
+ template.max_length = None
+ try:
+ yield
+ finally:
+ template.max_length = max_length
+
+ def _prepare_metrics(self):
+ args = self.args
+ from swift.utils import JsonlWriter
+ from collections import deque
+ self.log_completions = args.log_completions
+ self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
+ self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl'))
+ self.init_custom_metric = False
+ self._logs = {
+ 'prompt': deque(maxlen=args.generation_batch_size),
+ 'completion': deque(maxlen=args.generation_batch_size),
+ 'rewards': defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
+ 'advantages': deque(maxlen=args.generation_batch_size),
+ }
+ if is_wandb_available():
+            # When logging profiling metrics, the step differs from the training-loop step,
+            # so patch wandb's Run.log to drop the step argument.
+ from wandb.sdk.wandb_run import Run
+ origin_log = Run.log
+ from functools import wraps
+
+ @wraps(origin_log)
+ def log(self, data: dict[str, Any], step: int | None = None, commit: bool | None = None):
+ return origin_log(self, data, None, commit)
+
+ Run.log = log
+
+ self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
+
+ def _apply_chat_template_to_messages_list(self, messages_list: DataType):
+ prompts_text = []
+ for messages in messages_list:
+ remove_response(messages)
+ template_inputs = TemplateInputs.from_dict({'messages': messages})
+ res = self.template.encode(template_inputs)
+ prompts_text.append(self.template.safe_decode(res['input_ids']))
+ return prompts_text
+
+ def _set_inputs_system(self, batch: DataType) -> DataType:
+ """
+ Ensures the system message is consistently set for all conversations in the batch.
+
+ The template handles the user-defined system message. However, in server mode,
+ tokenization occurs on the rollout side. To prevent a mismatch where the system
+ message is set only during training but missing during rollout, this method
+ injects the default system message into each conversation if not already present.
+
+ Args:
+ batch: A list of data items, each containing a 'messages' list.
+
+ Returns:
+ The updated batch with the default system message inserted at the beginning
+ of each conversation that lacks one.
+ """
+
+ if self.vllm_mode != 'server':
+ return batch
+
+ # Return early if no default system message is defined
+ if not self.template.template_meta.default_system:
+ return batch
+
+ # Return early if all conversations already start with a system message
+ if all(data['messages'][0]['role'] == 'system' for data in batch):
+ return batch
+
+ # Insert the default system message at the beginning of each conversation
+ # that doesn't already have one
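+        # e.g. [{'role': 'user', 'content': 'hi'}] becomes
+        # [{'role': 'system', 'content': default_system}, {'role': 'user', 'content': 'hi'}]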
+ for data in batch:
+ messages = data['messages']
+ if messages[0]['role'] != 'system':
+ messages.insert(0, {'role': 'system', 'content': self.template.template_meta.default_system})
+
+ return batch
diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py
index c004d5f91b..1c4efce1c9 100644
--- a/swift/megatron/trainers/rlhf_mixin.py
+++ b/swift/megatron/trainers/rlhf_mixin.py
@@ -1,11 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from contextlib import contextmanager
+import torch
+import torch.distributed as dist
from megatron.core import mpu
from megatron.training import get_args, get_model
from megatron.training.checkpointing import load_checkpoint
from megatron.training.utils import unwrap_model
-from torch.distributed.nn import all_reduce
+from torch.distributed.nn import all_gather, all_reduce
from transformers.utils import ContextManagers
from swift.utils import get_logger
@@ -54,11 +56,18 @@ def null_ref_context(self):
for m in self.peft_models:
m.set_adapter('default')
- def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None):
+ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, per_token=False):
args = get_args()
per_token_logps = -output_tensor
loss_mask = labels != -100
per_token_logps = per_token_logps * loss_mask
+ if per_token:
+ # In CP mode, all_gather and reconstruct full sequence
+ if args.context_parallel_size > 1:
+ per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples
+ or packed_seq_params.num_samples)
+ return per_token_logps
+
if num_samples is None:
num_samples = packed_seq_params.num_samples * 2
cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size
@@ -69,3 +78,59 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None):
if args.context_parallel_size > 1:
all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group())
return all_logps
+
+ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
+ """
+ Generic method: In CP mode, all_gather and reconstruct full tensor sequences.
+ Works for both logps (float) and masks (bool/int).
+
+ Args:
+ tensor: [1, packed_len/cp_size] - CP-split tensor (any dtype)
+ packed_seq_params: PackedSeqParams object
+ num_samples: Number of samples in the batch
+
+ Returns:
+ output_full: [1, packed_len] - Full sequence tensor
+ """
+ args = get_args()
+ cp_size = args.context_parallel_size
+ cp_rank = mpu.get_context_parallel_rank()
+
+ # All-gather across CP ranks
+ output_list = [torch.empty_like(tensor) for _ in range(cp_size)]
+ torch.distributed.all_gather(output_list, tensor.contiguous(), group=mpu.get_context_parallel_group())
+ output_list[cp_rank] = tensor
+
+ # Reconstruct full sequence
+ # Shape: [1, packed_len/cp_size] -> [1, packed_len]
+ cu_seqlens_full = packed_seq_params.cu_seqlens_q
+ cu_seqlens_cp = cu_seqlens_full // cp_size
+
+ # Calculate total packed length
+ total_packed_len = cu_seqlens_full[num_samples].item()
+ output_full = tensor.new_zeros(1, total_packed_len)
+
+ # Reconstruct each sequence
+ for i in range(num_samples):
+ start_full = cu_seqlens_full[i].item()
+ end_full = cu_seqlens_full[i + 1].item()
+ seq_len = end_full - start_full
+
+ # Length of each chunk after CP split
+ chunk_len = seq_len // cp_size
+ half_chunk = chunk_len // 2
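+            # e.g. seq_len=8, cp_size=2: rank 0 holds tokens [0, 1] and [6, 7],
+            # rank 1 holds tokens [2, 3] and [4, 5] (rank r keeps chunks r and 2*cp_size-1-r)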
+
+ # Concatenate from each CP rank's output (load-balanced split)
+ for j in range(cp_size):
+ o = output_list[j][0]
+ start_cp = cu_seqlens_cp[i].item()
+
+ # Get two half chunks (CP's load-balanced split)
+ o0 = o[start_cp:start_cp + half_chunk]
+ o1 = o[start_cp + half_chunk:start_cp + chunk_len]
+
+ # Place back to full sequence
+ output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0
+ output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1
+
+ return output_full
diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py
index 6879fe23bf..594561cdd8 100644
--- a/swift/megatron/trainers/utils.py
+++ b/swift/megatron/trainers/utils.py
@@ -1,16 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict
+import functools
+import gc
+import time
+from contextlib import contextmanager
+from typing import Any, Dict, Optional
import megatron.core
import torch
+from accelerate.utils import gather as hf_gather
+from accelerate.utils import gather_object as hf_gather_object
from megatron.core import mpu
+from megatron.core.distributed import DistributedDataParallel as DDP
+from megatron.core.optimizer import ChainedOptimizer
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank
-from megatron.training import get_args
+from megatron.training import get_args, get_wandb_writer
from packaging import version
from swift.llm import get_packed_seq_params as _get_packed_seq_params
from swift.llm import to_device
+from swift.utils import get_logger
+from swift.utils.torch_utils import empty_cache, get_current_device
mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
@@ -105,6 +115,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
keys.append('decoder_input')
else:
keys.append('input_ids')
+
packed_seq_params = batch.get('packed_seq_params')
if packed_seq_params is None:
return mcore_get_batch_on_this_cp_rank(batch)
@@ -117,3 +128,245 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1)
return batch
+
+
+@contextmanager
+def profiling_context(trainer, name: str):
+ start_time = time.perf_counter()
+ yield
+ end_time = time.perf_counter()
+ duration = end_time - start_time
+
+ profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration}
+ wandb_writer = get_wandb_writer()
+ if wandb_writer and trainer.is_main_process:
+ wandb_writer.log(profiling_metrics)
+
+
+def profiling_decorator(func):
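+    # Usage sketch: decorate trainer methods (as done for forward_step and loss_func in
+    # the GRPO trainer) so wall-clock time is logged to wandb under
+    # 'profiling/Time taken: <ClassName>.<method>'.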
+
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ with profiling_context(self, func.__name__):
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
+def gather(tensor, group: Optional[torch.distributed.ProcessGroup] = None):
+ if group is None:
+ return hf_gather(tensor)
+ size = torch.distributed.get_world_size(group=group)
+ output = [torch.empty_like(tensor) for _ in range(size)]
+ torch.distributed.all_gather(output, tensor, group=group, async_op=False)
+
+ return torch.cat(output, dim=0)
+
+
+def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = None):
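+    # e.g. with two ranks holding ['a'] and ['b'], every rank receives ['a', 'b']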
+ if group is None:
+ return hf_gather_object(object)
+ size = torch.distributed.get_world_size(group=group)
+ output_objects = [None for _ in range(size)]
+ torch.distributed.all_gather_object(output_objects, object, group=group)
+ return [x for y in output_objects for x in y]
+
+
+# code borrowed from verl
+@torch.no_grad()
+def load_megatron_model_to_gpu(models, load_grad=True):
+ for model_chunk in models:
+ if isinstance(model_chunk, DDP):
+ model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]
+ for buffers in model_chunk_all_buffers:
+ for buffer in buffers:
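+                    # param_data_size, grad_data_size and param_data.cpu_data are attributes
+                    # stashed by offload_megatron_model_to_cpu; this assumes the model was
+                    # offloaded with that helper first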
+ # sometimes, we don't want to load grad for pure inference
+ if load_grad:
+ buffer.grad_data.storage().resize_(buffer.grad_data_size)
+ buffer.grad_data.zero_()
+
+ if buffer.param_data.storage().size() == 0:
+ buffer.param_data.storage().resize_(buffer.param_data_size)
+ # copy data from cpu to cuda
+ buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True)
+ else:
+ # we need this for ref module
+ device_id = get_current_device()
+ for _, param in model_chunk.named_parameters():
+ param.data = param.data.to(device_id, non_blocking=True)
+ if param.grad is not None:
+ param.grad = param.grad.to(device_id, non_blocking=True)
+ gc.collect()
+ empty_cache()
+
+
+@torch.no_grad()
+def offload_megatron_model_to_cpu(models):
+ """
+    In Megatron, model and optimizer state is stored as:
+ - bf16 parameter data chunked in model parallel group
+ - fp32 grad chunked in model parallel group
+ - fp32 main_parameter chunked in model and dp group
+ - fp32 optimizer state chunked in model and dp group
+ """
+ for model_chunk in models:
+ if isinstance(model_chunk, DDP):
+ model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]
+ for buffers in model_chunk_all_buffers:
+ for buffer in buffers:
+ # offload parameters
+ if buffer.param_data.storage().size() > 0:
+ buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory()
+ buffer.param_data_size = buffer.param_data.storage().size()
+ buffer.param_data.storage().resize_(0)
+
+ assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size()
+
+                    # a grad_data storage size of zero means the grads were already offloaded
+                    if buffer.grad_data.storage().size() > 0:
+ buffer.grad_data_size = buffer.grad_data.storage().size()
+ buffer.grad_data.storage().resize_(0)
+ else:
+ # we need this for ref module
+ for _, param in model_chunk.named_parameters():
+ param.data = param.data.to('cpu', non_blocking=True)
+ if param.grad is not None:
+ param.grad = param.grad.to('cpu', non_blocking=True)
+ gc.collect()
+ empty_cache()
+
+
+@torch.no_grad()
+def load_megatron_copy_params(optimizers):
+ """
+ Load optimizer parameters back to GPU. Handles ChainedOptimizer.
+
+ Args:
+ optimizers: Optimizer or ChainedOptimizer instance.
+ """
+
+ def _iter_opts(opt):
+ if isinstance(opt, ChainedOptimizer):
+ return opt.chained_optimizers
+ return [opt]
+
+ def load_tensor_to_gpu(tensor):
+ if tensor is None:
+ return
+ device_id = get_current_device()
+ tensor.data = tensor.data.to(device_id, non_blocking=True)
+
+ def load_group_to_gpu(group):
+ if group is None:
+ return
+
+ if isinstance(group, list):
+ for param_group in group:
+ if isinstance(param_group, list):
+ for param in param_group:
+ load_tensor_to_gpu(param)
+ else:
+ load_tensor_to_gpu(param_group)
+ else:
+ load_tensor_to_gpu(group)
+
+ # Load all parameter groups to GPU for each underlying optimizer
+
+ for _opt in _iter_opts(optimizers):
+ if hasattr(_opt, 'shard_fp32_from_float16_groups'):
+ load_group_to_gpu(_opt.shard_fp32_from_float16_groups)
+
+
+@torch.no_grad()
+def offload_megatron_copy_params(optimizers):
+ """
+ Offload optimizer parameters to CPU. Supports both Megatron optimizers
+ and `ChainedOptimizer`, which wraps a list of underlying optimizers.
+
+ Args:
+ optimizers: The optimizer or ChainedOptimizer instance.
+ """
+
+ def _iter_opts(opt):
+ if isinstance(opt, ChainedOptimizer):
+ return opt.chained_optimizers
+ return [opt]
+
+ def offload_tensor_to_cpu(tensor):
+ if tensor is None:
+ return
+ tensor.data = tensor.data.to('cpu', non_blocking=True)
+
+ def offload_group_to_cpu(group):
+ if group is None:
+ return
+
+ if isinstance(group, list):
+ for param_group in group:
+ if isinstance(param_group, list):
+ for param in param_group:
+ offload_tensor_to_cpu(param)
+ else:
+ offload_tensor_to_cpu(param_group)
+ else:
+ offload_tensor_to_cpu(group)
+
+ # Offload all parameter groups to CPU for each underlying optimizer
+
+ for _opt in _iter_opts(optimizers):
+ if hasattr(_opt, 'shard_fp32_from_float16_groups'):
+ offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)
+
+
+@torch.no_grad()
+def load_megatron_optimizer(optimizers):
+
+ def _iter_opts(opt):
+ if isinstance(opt, ChainedOptimizer):
+ return opt.chained_optimizers
+ return [opt]
+
+ for _opt in _iter_opts(optimizers):
+ load_megatron_copy_params(_opt)
+        # with HybridDeviceOptimizer, only the GPU-resident optimizer state should be moved back to the GPU
+ if hasattr(_opt.optimizer, '_move_new_state_to_right_device'):
+ _opt.optimizer._move_new_state_to_right_device()
+ else:
+ opt_state_dict_values = _opt.optimizer.state.values()
+ for v in opt_state_dict_values:
+ if 'exp_avg' in v:
+ v['exp_avg'] = v['exp_avg'].to(get_current_device(), non_blocking=True)
+ if 'exp_avg_sq' in v:
+ v['exp_avg_sq'] = v['exp_avg_sq'].to(get_current_device(), non_blocking=True)
+ gc.collect()
+ empty_cache()
+
+
+@torch.no_grad()
+def offload_megatron_optimizer(optimizers):
+
+ def _iter_opts(opt):
+ if isinstance(opt, ChainedOptimizer):
+ return opt.chained_optimizers
+ return [opt]
+
+ for _opt in _iter_opts(optimizers):
+ offload_megatron_copy_params(_opt)
+ opt_state_dict_values = _opt.optimizer.state.values()
+ for v in opt_state_dict_values:
+ if 'exp_avg' in v:
+ v['exp_avg'] = v['exp_avg'].to('cpu', non_blocking=True)
+ if 'exp_avg_sq' in v:
+ v['exp_avg_sq'] = v['exp_avg_sq'].to('cpu', non_blocking=True)
+ gc.collect()
+ empty_cache()
+
+
+def log_gpu_memory(prefix: str = '', info_once: bool = False):
+ logger = get_logger()
+ log_msg = (f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, '
+ f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved')
+ if info_once:
+ logger.info_once(log_msg, hash_id=prefix)
+ else:
+ logger.info(log_msg)
diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py
index 815fa63d5c..23fcd2b107 100644
--- a/swift/megatron/tuners/lora.py
+++ b/swift/megatron/tuners/lora.py
@@ -428,6 +428,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
def unmerge(self) -> None:
"""
Unmerge all merged adapter weights from the base weights.
+
This method reverses the merge operation by subtracting the LoRA delta weights
from the base layer weights, restoring the original base weights.
"""
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index 968f5cedf5..42f6afdcdd 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -324,7 +324,7 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0
- # GSPO https://www.arxiv.org/abs/2507.18071
+ # GSPO https://arxiv.org/abs/2507.18071
importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'
# RLOO, REINFORCE++
diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py
index 8830dbac20..829dba091b 100644
--- a/swift/trainers/rlhf_trainer/__init__.py
+++ b/swift/trainers/rlhf_trainer/__init__.py
@@ -14,6 +14,7 @@
from .gkd_trainer import GKDTrainer
from .rlhf_mixin import RLHFTrainerMixin
from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection
+ from .vllm_client import VLLMClient
else:
_import_structure = {
'cpo_trainer': ['CPOTrainer'],
@@ -26,6 +27,7 @@
'gkd_trainer': ['GKDTrainer'],
'rlhf_mixin': ['RLHFTrainerMixin'],
'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'],
+ 'vllm_client': ['VLLMClient'],
}
import sys
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 1b81ffaeb0..53cc4b5c99 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1861,7 +1861,7 @@ def _prepare_algorithm_params(self):
# Entropy Mask, https://arxiv.org/abs/2506.01939
self.top_entropy_quantile = args.top_entropy_quantile
- # GSPO, https://www.arxiv.org/abs/2507.18071
+ # GSPO, https://arxiv.org/abs/2507.18071
self.importance_sampling_level = args.importance_sampling_level
# RLOO,
diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py
index 17d8210021..3cb154a94c 100644
--- a/swift/trainers/rlhf_trainer/rollout_mixin.py
+++ b/swift/trainers/rlhf_trainer/rollout_mixin.py
@@ -637,6 +637,7 @@ def _fast_infer(self, inputs: DataType) -> DataType:
if self.engine.inner_model_executor.is_sleeping:
wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters
kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {}
+ aggressive_empty_cache()
self.engine.engine.wake_up(**kwargs)
if self.state.global_step != self._last_loaded_step:
diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py
index 81e614a89c..2de38550d6 100644
--- a/swift/trainers/rlhf_trainer/vllm_client.py
+++ b/swift/trainers/rlhf_trainer/vllm_client.py
@@ -133,9 +133,14 @@ def infer(
results = [None] * self.num_servers
errors = [None] * self.num_servers
+ if isinstance(request_config, RequestConfig):
+ request_config = asdict(request_config)
def process_chunk(i, chunk):
try:
+ if len(chunk) > 0 and isinstance(chunk[0], RolloutInferRequest):
+ chunk = [asdict(req) for req in chunk]
+
response = self.sessions[i].post(
f'{self.base_urls[i]}/infer/',
json={
@@ -208,7 +213,7 @@ def init_communicator(self, device: Union[int, str] = 0):
pg = StatelessProcessGroup.create(
host=self.hosts[i], port=self.group_ports[i], rank=rank, world_size=world_size)
- comm = PyNcclCommunicator(pg, device=0)
+ comm = PyNcclCommunicator(pg, device=device)
self.pynccl_comms.append(comm)
atexit.register(self.close_communicator)