From 4316425d539a7e351ff42081eaf41394815472a5 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 29 Aug 2025 18:07:26 +0800 Subject: [PATCH 01/83] wip --- swift/megatron/argument/megatron_args.py | 138 +++++++++++++++++++++- swift/megatron/argument/rlhf_args.py | 1 - swift/megatron/train/rlhf.py | 4 +- swift/megatron/trainers/__init__.py | 1 + swift/megatron/trainers/dpo_trainer.py | 91 +------------- swift/megatron/trainers/grpo_trainer.py | 144 +++++++++++++++++++++++ swift/megatron/trainers/rlhf_base.py | 109 +++++++++++++++++ 7 files changed, 396 insertions(+), 92 deletions(-) create mode 100644 swift/megatron/trainers/grpo_trainer.py create mode 100644 swift/megatron/trainers/rlhf_base.py diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 9628b4da6a..fbc7db7db1 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -10,13 +10,14 @@ from transformers.utils.versions import require_version from swift.llm.argument.base_args import to_abspath -from swift.utils import get_dist_setting, get_logger, json_parse_to_dict +from swift.utils import get_current_device, get_dist_setting, get_logger, is_master, json_parse_to_dict logger = get_logger() @dataclass class RLHFMegatronArgumentsMixin: + rlhf_type: Literal['dpo', 'grpo'] = 'dpo' ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -27,6 +28,141 @@ class RLHFMegatronArgumentsMixin: f_divergence_type: str = 'reverse_kl' loss_type: str = 'sigmoid' + # =========================== GRPO =========================== + # ─────────────────────────── Sampling ─────────────────────────── + epsilon: float = 0.2 + epsilon_high: Optional[float] = None + delta: Optional[float] = None + top_k: int = 50 + top_p: float = 0.9 + repetition_penalty: float = 1. + # ─────────────────────────── VLLM ─────────────────────────── + use_vllm: bool = False + vllm_mode: Literal['server', 'colocate'] = 'colocate' + # ────────────── Internal VLLM (colocate) ────────────── + vllm_enable_prefix_caching: bool = True + vllm_gpu_memory_utilization: float = 0.9 + vllm_tensor_parallel_size: int = 1 + vllm_max_model_len: Optional[int] = None + vllm_enforce_eager: bool = False + vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' + vllm_disable_cascade_attn: bool = False + sleep_level: Literal[0, 1, 2] = 0 + + # ────────────── External VLLM (server) ────────────── + vllm_server_base_url: Optional[List[str]] = None + vllm_server_host: Optional[List[str]] = None + vllm_server_port: List[int] = field(default_factory=lambda: [8000]) + vllm_server_timeout: float = 240.0 + vllm_client: Optional[object] = field(init=False, default=None) + + # ─────────────────────────── Reward ─────────────────────────── + # see details in swift/plugin/orm.py + # cosine reward, https://arxiv.org/abs/2502.03373 + cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length. + cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length. + cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length. + cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length. 
+ cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length + # repetition penalty, https://arxiv.org/abs/2502.03373 + repetition_n_grams: int = 3 + repetition_max_penalty: float = -1.0 + # soft_overlong, https://arxiv.org/abs/2503.14476 + soft_max_length: Optional[int] = None + soft_cache_length: Optional[int] = None + + reward_model: Optional[List[str]] = None + reward_model_plugin: Optional[List[str]] = None + + # ─────────────────────────── Not Supported Yet ─────────────────────────── + # sync ref model + sync_ref_model: bool = False + ref_model_sync_steps: int = 512 + ref_model_mixup_alpha: float = 0.6 + + async_generate: bool = False + + move_model_batches: Optional[int] = None + offload_optimizer: bool = False + offload_model: bool = False + gc_collect_after_offload: bool = False # deprecated + + # multi turn + multi_turn_func: Optional[str] = None # deprecated + multi_turn_scheduler: Optional[str] = None + max_turns: Optional[int] = None + completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round' + vllm_server_pass_dataset: bool = False + + # DAPO, https://arxiv.org/abs/2503.14476 + dynamic_sample: bool = False + max_resample_times: int = 3 + overlong_filter: bool = False + + # Dr. GRPO, https://arxiv.org/abs/2503.20783 + scale_rewards: bool = True + + # entropy + log_entropy: bool = False + # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 + top_entropy_quantile: float = 1.0 + + # GSPO https://www.arxiv.org/abs/2507.18071 + importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' + + wandb_log_unique_prompts: Optional[bool] = None + generation_batch_size: Optional[int] = None + steps_per_generation: Optional[int] = None + + num_iterations: int = 1 + + # dataset + dataset_shuffle: Optional[bool] = True + + def __post_init__(self): + if self.rlhf_type == 'grpo': + self._init_grpo() + super().__post_init__() + if self.rlhf_type == 'grpo': + self._set_grpo_default() + + def _set_grpo_default(self): + if self.beta is None: + self.beta = 0.04 # https://arxiv.org/abs/2402.03300 + + def _init_grpo(self): + + def _init_external_vllm(): + if self.rlhf_type != 'grpo' or (self.vllm_server_host is None and self.vllm_server_base_url is None): + return + from swift.trainers.rlhf_trainer.vllm_client import VLLMClient + if is_master(): + logger.info('Start connecting to vLLM server') + self.vllm_client = VLLMClient( + base_urls=self.vllm_server_base_url, + hosts=self.vllm_server_host, + server_ports=self.vllm_server_port, + connection_timeout=self.vllm_server_timeout) + self.vllm_client.init_communicator(device=get_current_device()) + logger.info('Connected to vLLM server') + + def _check_not_supported(): + # TODO: check + # bool + not_supported_args = [ + 'sync_ref_model', + 'async_generate', + ] + for arg in not_supported_args: + if getattr(self, arg): + raise ValueError(f'{arg} is not supported for Megatron-GRPO yet, please unset it.') + # else + if self.num_iterations > 1: + raise ValueError('num_iterations > 1 is not supported for Megatron-GRPO yet, please set it to 1.') + + _init_external_vllm() + _check_not_supported() + @dataclass class MegatronTunerMixin: diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py index 2119c54dcd..513e26c2ab 100644 --- a/swift/megatron/argument/rlhf_args.py +++ b/swift/megatron/argument/rlhf_args.py @@ -7,7 +7,6 @@ @dataclass class MegatronRLHFArguments(MegatronTrainArguments): - rlhf_type: Literal['dpo'] = 'dpo' loss_scale: str = 
'last_round' calculate_per_token_loss: bool = False diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 1e36f27cbc..c8f3129b4b 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -3,7 +3,7 @@ from swift.utils import get_logger from ..argument import MegatronRLHFArguments -from ..trainers import MegatronDPOTrainer +from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer from .sft import MegatronSft logger = get_logger() @@ -17,6 +17,8 @@ def prepare_trainer(self): args = self.args if args.rlhf_type == 'dpo': trainer_cls = MegatronDPOTrainer + elif args.rlhf_type == 'grpo': + trainer_cls = MegatronGRPOTrainer else: raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.') return trainer_cls(args) diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py index c891081541..835b984590 100644 --- a/swift/megatron/trainers/__init__.py +++ b/swift/megatron/trainers/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .dpo_trainer import MegatronDPOTrainer +from .grpo_trainer import MegatronGRPOTrainer from .trainer import MegatronTrainer diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 7798de2b08..5df60605c8 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -13,6 +13,7 @@ from swift.trainers import DPOTrainer from swift.utils import get_current_device, get_logger +from .rlhf_base import MegatronRLHFTrainer from .trainer import MegatronTrainer from .utils import get_batch @@ -33,81 +34,12 @@ def __init__(self, args): self.beta = args.beta -class MegatronDPOTrainer(MegatronTrainer): +class MegatronDPOTrainer(MegatronRLHFTrainer): def __init__(self, args): super().__init__(args) self.dummy_dpo_trainer = DummyDPOTrainer(args) - def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - args = get_args() - if args.train_type == 'full': - ref_model = get_model(model_provider_func, model_type) - if args.ref_load is None: - args.ref_load = args.load - args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( - ref_model, None, None, load_arg='ref_load') - self.ref_model = ref_model[0] - self.ref_model.eval() - else: - self.ref_model = None - return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) - - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return 
output_tensor - - def ref_forward(self, ref_model, data_iterator): - with self.stimer(bdata=True): - data = get_batch(data_iterator) - data.pop('loss_scale', None) - labels = data.get('labels') - with torch.no_grad(): - output_tensor = self._forward_step_helper(ref_model, data) - data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) - return data - - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples * 2, )) - for i in range(num_samples * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, packed_seq_params): args = get_args() @@ -150,25 +82,6 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab loss = loss / mpu.get_context_parallel_world_size() return loss, reporting_metric - @contextmanager - def null_ref_context(self): - args = get_args() - if args.train_type == 'full': - context = nullcontext() - ref_model = unwrap_model(self.ref_model) - else: - if args.ref_adapter_load is None: - context = self.peft_model.disable_adapter() - else: - context = nullcontext() - ref_model = self.unwrapped_model - with context: - if args.ref_adapter_load: - self.peft_model.set_adapter('ref_adapter') - yield ref_model - if args.ref_adapter_load: - self.peft_model.set_adapter('default') - def _replace_data_iterator(self, data_iterator): args = get_args() num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py new file mode 100644 index 0000000000..22c57bc9a5 --- /dev/null +++ b/swift/megatron/trainers/grpo_trainer.py @@ -0,0 +1,144 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from functools import partial + +import torch +from megatron.core import mpu +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank +from megatron.training import get_args, get_model, training +from megatron.training.checkpointing import load_checkpoint +from megatron.training.utils import unwrap_model +from torch.distributed.nn import all_reduce + +from swift.utils import get_current_device, get_logger, is_vllm_available +from ..argument import MegatronRLHFArguments +from .rlhf_base import MegatronRLHFTrainer +from .trainer import MegatronTrainer +from .utils import get_batch + +logger = get_logger() + + +class MegatronGRPOTrainer(MegatronRLHFTrainer): + + def __init__(self, args: MegatronRLHFArguments): + super().__init__(args) + # TODO: init vllm + self.use_vllm = args.use_vllm + vllm_client = args.vllm_client + if self.use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.vllm_mode == 'server': + self.vllm_client: VLLMClient = vllm_client + if self.accelerator.is_main_process: + self.vllm_client.get_engine_type() + vllm_use_async_engine = [self.vllm_client.use_async_engine] + use_gym_env = [self.vllm_client.use_gym_env] + enable_multi_turn = [self.vllm_client.enable_multi_turn] + else: + vllm_use_async_engine = [False] + use_gym_env = [False] + enable_multi_turn = [self.enable_server_multi_turn] + self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] + self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] + self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + if self.use_gym_env: + self.reward_func_names = ['gym_reward'] + + elif self.vllm_mode == 'colocate': + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.accelerator.num_processes}) evenly.') + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. + # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ + list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ]) + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + self.engine = self.prepare_vllm(model) + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + + else: + from swift.llm import PtEngine + self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit + + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): + # prepare global batch data here + new_data_iterator = self._replace_data_iterator(data_iterator) + return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, + config) + + def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, + packed_seq_params): + args = get_args() + num_samples = packed_seq_params.num_samples + + logps = self.get_logps(output_tensor, labels, packed_seq_params) + loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( + logps[:num_samples], + logps[num_samples:], + ref_logps[:num_samples], + ref_logps[num_samples:], + ) + if args.rpo_alpha: + loss_mask = labels != -100 + num_tokens = packed_seq_params.cu_seqlens_q[num_samples] // args.context_parallel_size + loss_mask[:, num_tokens:] = 0 + nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]]) + if args.context_parallel_size > 1: + nll_loss = all_reduce(nll_loss, group=mpu.get_context_parallel_group()) + nll_loss = nll_loss[0] / nll_loss[1] + loss = loss + args.rpo_alpha * nll_loss + loss = loss.mean() + metric = { + 'loss': loss.clone().detach(), + 'logps/chosen': logps[:num_samples].mean(), + 'logps/rejected': logps[num_samples:].mean(), + 'rewards/chosen': chosen_rewards.mean(), + 'rewards/rejected': rejected_rewards.mean(), + 'rewards/accuracies': (chosen_rewards 
> rejected_rewards).float().mean(), + 'rewards/margins': (chosen_rewards - rejected_rewards).mean(), + } + if args.rpo_alpha: + metric['nll_loss'] = nll_loss.detach() + reporting_metric = loss.new_tensor(list(metric.values())) + torch.distributed.all_reduce( + reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} + # fix megatron-lm bug + # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 + loss = loss / mpu.get_context_parallel_world_size() + return loss, reporting_metric + + def _replace_data_iterator(self, data_iterator): + args = get_args() + num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) + res = [] + with torch.no_grad(), self.null_ref_context() as ref_model: + for i in range(num_iters_per_step): + res.append(self.ref_forward(ref_model, data_iterator)) + return iter(res) + + def forward_step(self, data_iterator, model): + with torch.no_grad(): + data = next(data_iterator) + + ref_logps = data.pop('logps') + with self.stimer: + output_tensor = model(**data) + return output_tensor, partial( + self.loss_func, + ref_logps=ref_logps, + labels=data.get('labels'), + packed_seq_params=data.get('packed_seq_params')) diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py new file mode 100644 index 0000000000..26064df297 --- /dev/null +++ b/swift/megatron/trainers/rlhf_base.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from functools import partial + +import torch +from megatron.core import mpu +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank +from megatron.training import get_args, get_model, training +from megatron.training.checkpointing import load_checkpoint +from megatron.training.utils import unwrap_model +from torch.distributed.nn import all_reduce + +from swift.utils import get_current_device, get_logger +from .trainer import MegatronTrainer +from .utils import get_batch + +logger = get_logger() + + +class MegatronRLHFTrainer(MegatronTrainer): + + def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): + args = get_args() + if args.train_type == 'full': + ref_model = get_model(model_provider_func, model_type) + if args.ref_load is None: + args.ref_load = args.load + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + ref_model, None, None, load_arg='ref_load') + self.ref_model = ref_model[0] + self.ref_model.eval() + else: + self.ref_model = None + return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) + + @staticmethod + def _forward_step_helper(model, inputs): + args = get_args() + if mpu.is_pipeline_first_stage(): + micro_batch_size = 1 # use qkv_format 'thd' + seq_length = inputs['input_ids'].shape[1] + if args.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64) + else: + recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + if not mpu.is_pipeline_last_stage(): + 
send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + output_tensor = model(**inputs) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor + + def ref_forward(self, ref_model, data_iterator): + with self.stimer(bdata=True): + data = get_batch(data_iterator) + data.pop('loss_scale', None) + labels = data.get('labels') + with torch.no_grad(): + output_tensor = self._forward_step_helper(ref_model, data) + data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) + return data + + @staticmethod + def get_logps(output_tensor, labels, packed_seq_params): + args = get_args() + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + num_samples = packed_seq_params.num_samples + cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size + all_logps = per_token_logps.new_zeros((num_samples * 2, )) + for i in range(num_samples * 2): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) + return all_logps + + @contextmanager + def null_ref_context(self): + args = get_args() + if args.train_type == 'full': + context = nullcontext() + ref_model = unwrap_model(self.ref_model) + else: + if args.ref_adapter_load is None: + context = self.peft_model.disable_adapter() + else: + context = nullcontext() + ref_model = self.unwrapped_model + with context: + if args.ref_adapter_load: + self.peft_model.set_adapter('ref_adapter') + yield ref_model + if args.ref_adapter_load: + self.peft_model.set_adapter('default') From 5d46eae61718b5d61446d11a45250fd1e9079eb8 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 1 Sep 2025 14:47:52 +0800 Subject: [PATCH 02/83] init wip --- swift/megatron/argument/megatron_args.py | 40 +++++++++++++++++++ swift/megatron/trainers/grpo_trainer.py | 51 +++++++++++++++++++++--- swift/trainers/rlhf_trainer/__init__.py | 2 + 3 files changed, 88 insertions(+), 5 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index fbc7db7db1..2a0f7522fd 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -57,6 +57,8 @@ class RLHFMegatronArgumentsMixin: vllm_client: Optional[object] = field(init=False, default=None) # ─────────────────────────── Reward ─────────────────────────── + reward_funcs: List[str] = field(default_factory=list) + reward_weights: List[float] = None # see details in swift/plugin/orm.py # cosine reward, https://arxiv.org/abs/2502.03373 cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length. 
@@ -75,6 +77,9 @@ class RLHFMegatronArgumentsMixin: reward_model_plugin: Optional[List[str]] = None # ─────────────────────────── Not Supported Yet ─────────────────────────── + # reward model + reward_model: Optional[List[str]] = None + reward_model_plugin: Optional[List[str]] = None # sync ref model sync_ref_model: bool = False ref_model_sync_steps: int = 512 @@ -162,6 +167,41 @@ def _check_not_supported(): _init_external_vllm() _check_not_supported() + if self.use_vllm: + set_default_ddp_config() + if self.async_generate or not self.use_vllm: + self.sleep_level = 0 + self.remove_unused_columns = False + logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') + if self.truncation_strategy is None: + self.truncation_strategy = 'left' + assert self.truncation_strategy in ['left', 'delete' + ], ("GRPO requires `truncation_strategy 'left' or 'delete'`, " + f"Current value: `truncation_strategy='{self.truncation_strategy}'`." + ) # noqa + if self.beta is None: + self.beta = 0.04 # https://arxiv.org/abs/2402.03300 + if self.async_generate: + logger.info('Using async mode. This is a approximate version which ' + 'will use the old weights to generate responses to accelerate. ' + 'This will ignore the `CLIP` of advantages, if you found the training ' + 'is unstable, you may consider using --async_generate false.') + if 'soft_overlong' in self.reward_funcs: + assert self.soft_cache_length is not None, \ + 'The soft_cache_length must be set when using soft overlong rewards.' + if self.soft_max_length is None: + self.soft_max_length = self.max_completion_length + logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}') + if self.use_vllm: + # set vllm mode + if self.vllm_server_host is not None or self.vllm_server_base_url is not None: + if self.vllm_mode != 'server': + self.vllm_mode = 'server' + logger.warning('set vllm_mode to `server` since vllm server host/base_url is provided') + else: + if self.vllm_mode != 'colocate': + self.vllm_mode = 'colocate' + logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') @dataclass diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 22c57bc9a5..0d9f7fbf2a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1,9 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import inspect from collections import namedtuple from contextlib import contextmanager, nullcontext from functools import partial import torch +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from megatron.core import mpu from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank from megatron.training import get_args, get_model, training @@ -11,6 +13,8 @@ from megatron.training.utils import unwrap_model from torch.distributed.nn import all_reduce +from swift.plugin import orms +from swift.trainers.rlhf_trainer import GRPOTrainer, VLLMClient from swift.utils import get_current_device, get_logger, is_vllm_available from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer @@ -20,11 +24,48 @@ logger = get_logger() -class MegatronGRPOTrainer(MegatronRLHFTrainer): +class MegatronGRPOTrainer(MegatronRLHFTrainer, GRPOTrainer): def __init__(self, args: MegatronRLHFArguments): - super().__init__(args) + MegatronRLHFTrainer().__init__(args) # TODO: init vllm + self.args = args + self.processing_class = self.processor + reward_funcs = args.reward_funcs + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + # TODO: reward model + # TODO: multi turn scheduler(colocate multi turn) + + self.num_generations = args.num_generations + self.temperature = args.temperature + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.loss_type = args.loss_type + self.max_completion_length = args.max_completion_length self.use_vllm = args.use_vllm vllm_client = args.vllm_client if self.use_vllm: @@ -61,13 +102,13 @@ def __init__(self, args: MegatronRLHFArguments): list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) ]) - self.enable_offload = self.args.offload_model or self.args.offload_optimizer + self.enable_offload = args.offload_model or args.offload_optimizer context = self.offload_context if self.enable_offload else nullcontext with context(): - self.engine = self.prepare_vllm(model) + self.engine = self.prepare_vllm(self.unwrapped_model) if self.args.sleep_level > 0: - self.engine.engine.sleep(self.args.sleep_level) + self.engine.engine.sleep(args.sleep_level) else: from swift.llm import PtEngine diff --git 
a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py index 8830dbac20..829dba091b 100644 --- a/swift/trainers/rlhf_trainer/__init__.py +++ b/swift/trainers/rlhf_trainer/__init__.py @@ -14,6 +14,7 @@ from .gkd_trainer import GKDTrainer from .rlhf_mixin import RLHFTrainerMixin from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection + from .vllm_client import VLLMClient else: _import_structure = { 'cpo_trainer': ['CPOTrainer'], @@ -26,6 +27,7 @@ 'gkd_trainer': ['GKDTrainer'], 'rlhf_mixin': ['RLHFTrainerMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], + 'vllm_client': ['VLLMClient'], } import sys From 582822940cc0eaf1a10dbea70d19985bfbbf0a1c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 1 Sep 2025 16:36:05 +0800 Subject: [PATCH 03/83] args wip --- swift/megatron/argument/megatron_args.py | 6 ++++-- swift/megatron/trainers/grpo_trainer.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 2a0f7522fd..2ee3a1c7ac 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -29,6 +29,10 @@ class RLHFMegatronArgumentsMixin: loss_type: str = 'sigmoid' # =========================== GRPO =========================== + num_generations: int = 8 + mini_batch_size: int = 4 + max_completion_length: int = 512 + # ─────────────────────────── Sampling ─────────────────────────── epsilon: float = 0.2 epsilon_high: Optional[float] = None @@ -167,8 +171,6 @@ def _check_not_supported(): _init_external_vllm() _check_not_supported() - if self.use_vllm: - set_default_ddp_config() if self.async_generate or not self.use_vllm: self.sleep_level = 0 self.remove_unused_columns = False diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 0d9f7fbf2a..bd883da631 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -59,7 +59,8 @@ def __init__(self, args: MegatronRLHFArguments): # TODO: reward model # TODO: multi turn scheduler(colocate multi turn) - self.num_generations = args.num_generations + self.num_generations = args.num_generations # G in the GRPO paper + self.mini_batch_size = args.mini_batch_size self.temperature = args.temperature self.vllm_mode = args.vllm_mode self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode From 0689b7622d589454d3ec96720b4b41c2348a1623 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 3 Sep 2025 11:22:41 +0800 Subject: [PATCH 04/83] reuse _prepare_rollout_engine --- swift/megatron/trainers/grpo_trainer.py | 67 +++-------- swift/trainers/rlhf_trainer/grpo_trainer.py | 125 ++++++++++---------- 2 files changed, 77 insertions(+), 115 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index bd883da631..8297ac42f5 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -27,10 +27,11 @@ class MegatronGRPOTrainer(MegatronRLHFTrainer, GRPOTrainer): def __init__(self, args: MegatronRLHFArguments): - MegatronRLHFTrainer().__init__(args) + MegatronRLHFTrainer().__init__(self, args) # TODO: init vllm self.args = args self.processing_class = self.processor + # set up reward funcs reward_funcs = args.reward_funcs if not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] @@ -58,62 +59,21 
@@ def __init__(self, args: MegatronRLHFArguments): self.reward_func_names.append(reward_func_name) # TODO: reward model # TODO: multi turn scheduler(colocate multi turn) - + self._prepare_rollout_engine self.num_generations = args.num_generations # G in the GRPO paper - self.mini_batch_size = args.mini_batch_size self.temperature = args.temperature - self.vllm_mode = args.vllm_mode - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.loss_type = args.loss_type self.max_completion_length = args.max_completion_length - self.use_vllm = args.use_vllm - vllm_client = args.vllm_client - if self.use_vllm: - if not is_vllm_available(): - raise ImportError('vLLM is not available and `use_vllm` is set to True. ' - 'Please install vLLM with `pip install vllm -U` to use it.') - if self.vllm_mode == 'server': - self.vllm_client: VLLMClient = vllm_client - if self.accelerator.is_main_process: - self.vllm_client.get_engine_type() - vllm_use_async_engine = [self.vllm_client.use_async_engine] - use_gym_env = [self.vllm_client.use_gym_env] - enable_multi_turn = [self.vllm_client.enable_multi_turn] - else: - vllm_use_async_engine = [False] - use_gym_env = [False] - enable_multi_turn = [self.enable_server_multi_turn] - self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] - self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] - self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] - if self.use_gym_env: - self.reward_func_names = ['gym_reward'] - - elif self.vllm_mode == 'colocate': - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' - f'({self.accelerator.num_processes}) evenly.') - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. 
- # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ]) - self.enable_offload = args.offload_model or args.offload_optimizer - context = self.offload_context if self.enable_offload else nullcontext - - with context(): - self.engine = self.prepare_vllm(self.unwrapped_model) - if self.args.sleep_level > 0: - self.engine.engine.sleep(args.sleep_level) - - else: - from swift.llm import PtEngine - self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self.top_entropy_quantile = args.top_entropy_quantile + self.importance_sampling_level = args.importance_sampling_level + self.enable_offload = False + self.use_gym_env = False + # batch size + self.global_batch_size = args.global_batch_size + self.mini_batch_size = args.mini_batch_size + self.micro_batch_size = args.micro_batch_size def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): # prepare global batch data here @@ -173,6 +133,7 @@ def _replace_data_iterator(self, data_iterator): return iter(res) def forward_step(self, data_iterator, model): + with torch.no_grad(): data = next(data_iterator) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index fed2d318f4..cf629118f7 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -176,24 +176,16 @@ def __init__(self, self.num_generations = args.num_generations self.temperature = args.temperature - self.vllm_mode = args.vllm_mode - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.loss_type = args.loss_type self.max_completion_length = args.max_completion_length self.completion_length_limit_scope = args.completion_length_limit_scope model.warnings_issued['estimate_tokens'] = True - kwargs['data_collator'] = identity_data_collator # No data collation is needed in GRPO self.shuffle_dataset = args.dataset_shuffle - - self.use_vllm = args.use_vllm - self.async_generate = args.async_generate - vllm_client = kwargs.pop('vllm_client') # for external vllm - self.model_kwarg_keys = ( inspect.signature(model.forward).parameters.keys() if not hasattr(model, 'get_base_model') else inspect.signature(model.get_base_model().forward).parameters.keys()) + self.vllm_client = kwargs.pop('vllm_client') super().__init__(model, ref_model, *_args, **kwargs) if self.args.eval_strategy != 'no': @@ -247,59 +239,7 @@ def __init__(self, set_seed(args.seed, device_specific=True) if is_peft_model(self.model): self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() - self.use_fast_infer = self.use_vllm # whether to use the PT backend - self.vllm_use_async_engine = False - self.enable_offload = False - self.use_gym_env = False - self.enable_server_multi_turn = False - # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs - self.dynamic_num_samples = False - if self.use_vllm: - if not 
is_vllm_available(): - raise ImportError('vLLM is not available and `use_vllm` is set to True. ' - 'Please install vLLM with `pip install vllm -U` to use it.') - if self.vllm_mode == 'server': - self.vllm_client: VLLMClient = vllm_client - if self.accelerator.is_main_process: - self.vllm_client.get_engine_type() - vllm_use_async_engine = [self.vllm_client.use_async_engine] - use_gym_env = [self.vllm_client.use_gym_env] - enable_multi_turn = [self.vllm_client.enable_multi_turn] - else: - vllm_use_async_engine = [False] - use_gym_env = [False] - enable_multi_turn = [self.enable_server_multi_turn] - self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] - self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] - self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] - if self.use_gym_env: - self.reward_func_names = ['gym_reward'] - - elif self.vllm_mode == 'colocate': - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' - f'({self.accelerator.num_processes}) evenly.') - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. - # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ]) - self.enable_offload = self.args.offload_model or self.args.offload_optimizer - context = self.offload_context if self.enable_offload else nullcontext - - with context(): - self.engine = self.prepare_vllm(model) - if self.args.sleep_level > 0: - self.engine.engine.sleep(self.args.sleep_level) - - else: - from swift.llm import PtEngine - self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit - + self._prepare_rollout_engine() if not self.reward_funcs and not self.use_gym_env: raise ValueError('You must specify reward_funcs or reward_model') @@ -2808,3 +2748,64 @@ def _get_last_indices(self, request_ids: List[str]) -> torch.Tensor: for i, rid in enumerate(request_ids): seen[rid] = i return torch.tensor(list(seen.values()), dtype=torch.long, device=self.accelerator.device) + + def _prepare_rollout_engine(self, model): + args = self.args + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.use_vllm = args.use_vllm + self.async_generate = args.async_generate + vllm_client = getattr(args, 'vllm_client') or getattr(self, 'vllm_client') # for external vllm + self.use_fast_infer = self.use_vllm # whether to use the PT backend + self.vllm_use_async_engine = False + self.enable_offload = False + self.use_gym_env = False + self.enable_server_multi_turn = False + # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs + self.dynamic_num_samples = False + if self.use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.vllm_mode == 'server': + self.vllm_client: VLLMClient = vllm_client + if self.accelerator.is_main_process: + self.vllm_client.get_engine_type() + vllm_use_async_engine = [self.vllm_client.use_async_engine] + use_gym_env = [self.vllm_client.use_gym_env] + enable_multi_turn = [self.vllm_client.enable_multi_turn] + else: + vllm_use_async_engine = [False] + use_gym_env = [False] + enable_multi_turn = [self.enable_server_multi_turn] + self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] + self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] + self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + if self.use_gym_env: + self.reward_func_names = ['gym_reward'] + + elif self.vllm_mode == 'colocate': + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.accelerator.num_processes}) evenly.') + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. + # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ + list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ]) + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + self.engine = self.prepare_vllm(model) + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + + else: + from swift.llm import PtEngine + self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit From 3da8756c566bbf8d41069df05052cc3e4a29aa05 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 11:33:40 +0800 Subject: [PATCH 05/83] mega wip --- swift/megatron/argument/megatron_args.py | 1 + swift/megatron/trainers/grpo_trainer.py | 103 ++++++++++++++++++++++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 8e8eae5765..dc15561896 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -318,6 +318,7 @@ class MegatronArguments(ExtraMegatronArguments): dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic' manual_gc: bool = False manual_gc_interval: int = 0 + use_mbridge: bool = False # learning rate lr: Optional[float] = None diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 8297ac42f5..7cdc740375 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -20,11 +20,15 @@ from .rlhf_base import MegatronRLHFTrainer from .trainer import MegatronTrainer from .utils import get_batch +try: + from mbridge import AutoBridge +except ImportError: + pass logger = get_logger() -class MegatronGRPOTrainer(MegatronRLHFTrainer, GRPOTrainer): +class MegatronGRPOTrainer(MegatronRLHFTrainer): def __init__(self, args: MegatronRLHFArguments): MegatronRLHFTrainer().__init__(self, args) @@ -59,7 +63,7 @@ def __init__(self, args: 
MegatronRLHFArguments): self.reward_func_names.append(reward_func_name) # TODO: reward model # TODO: multi turn scheduler(colocate multi turn) - self._prepare_rollout_engine + self._prepare_rollout_engine(self.unwrapped_model) self.num_generations = args.num_generations # G in the GRPO paper self.temperature = args.temperature self.loss_type = args.loss_type @@ -74,6 +78,14 @@ def __init__(self, args: MegatronRLHFArguments): self.global_batch_size = args.global_batch_size self.mini_batch_size = args.mini_batch_size self.micro_batch_size = args.micro_batch_size + self.local_path = 'Qwen/Qwen2.5-7B-Instruct' # debug + if args.use_mbridge: + # debug: use mbridge to convert mcore to hf + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=True) + bridge = AutoBridge.from_pretrained(hf_config) + tf_config = bridge.config + self.bridge = bridge def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): # prepare global batch data here @@ -145,3 +157,90 @@ def forward_step(self, data_iterator, model): ref_logps=ref_logps, labels=data.get('labels'), packed_seq_params=data.get('packed_seq_params')) + + def _prepare_rollout_engine(self, model): + args = self.args + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.use_vllm = args.use_vllm + self.async_generate = args.async_generate + vllm_client = getattr(args, 'vllm_client') or getattr(self, 'vllm_client') # for external vllm + self.use_fast_infer = self.use_vllm # whether to use the PT backend + self.vllm_use_async_engine = False + self.enable_offload = False + self.use_gym_env = False + self.enable_server_multi_turn = False + # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs + self.dynamic_num_samples = False + if self.use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. ' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.vllm_mode == 'server': + self.vllm_client: VLLMClient = vllm_client + if self.accelerator.is_main_process: + self.vllm_client.get_engine_type() + vllm_use_async_engine = [self.vllm_client.use_async_engine] + use_gym_env = [self.vllm_client.use_gym_env] + enable_multi_turn = [self.vllm_client.enable_multi_turn] + else: + vllm_use_async_engine = [False] + use_gym_env = [False] + enable_multi_turn = [self.enable_server_multi_turn] + self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] + self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] + self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + if self.use_gym_env: + self.reward_func_names = ['gym_reward'] + + elif self.vllm_mode == 'colocate': + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.accelerator.num_processes}) evenly.') + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. 
+ # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ + list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ]) + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + self.engine = self.prepare_vllm(model) + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + + + def prepare_vllm(self, model): + from swift.tuners import Swift + from swift.llm.infer.infer_engine import GRPOVllmEngine + max_num_seqs = ( + self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) + data_parallel_rank = mpu.get_data_parallel_rank() + + with Swift.grpo_context(model, self.template.processor): + engine = GRPOVllmEngine( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + use_async_engine=False, # TODO: async engine for colocate + tensor_parallel_size=self.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + enable_sleep_mode=self.args.sleep_level > 0, + max_model_len=self.args.vllm_max_model_len, + seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, + disable_cascade_attn=self.args.vllm_disable_cascade_attn, + load_format='dummy', + template=copy(self.template), + distributed_executor_backend='external_launcher', + ) + return engine From d9ec029c099c12d869a9ec2894708e69dc8dc001 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 17 Sep 2025 21:07:43 +0800 Subject: [PATCH 06/83] wip --- docs/source/_extra/url_aliases.py | 54 +++---- swift/megatron/train/sft.py | 2 +- swift/megatron/trainers/base.py | 3 +- swift/megatron/trainers/grpo_trainer.py | 187 +++++++++++++----------- swift/megatron/trainers/rlhf_base.py | 5 + 5 files changed, 139 insertions(+), 112 deletions(-) diff --git a/docs/source/_extra/url_aliases.py b/docs/source/_extra/url_aliases.py index daa18e0173..75faf906c2 100644 --- a/docs/source/_extra/url_aliases.py +++ b/docs/source/_extra/url_aliases.py @@ -6,52 +6,52 @@ # 常见问题 'faq': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/常见问题整理.html', '常见问题': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/常见问题整理.html', - + # 支持的模型和数据集 'models': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/支持的模型和数据集.html', '模型列表': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/支持的模型和数据集.html', '数据集列表': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/支持的模型和数据集.html', - + # 命令行参数 'params': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/命令行参数.html', '命令行参数': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/命令行参数.html', '参数说明': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/命令行参数.html', - + # 自定义数据集 'custom-dataset': 'https://swift.readthedocs.io/zh-cn/latest/Customization/自定义数据集.html', '自定义数据集': 'https://swift.readthedocs.io/zh-cn/latest/Customization/自定义数据集.html', - + # 推理和部署 'deploy': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/推理和部署.html', '推理部署': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/推理和部署.html', '部署': 
'https://swift.readthedocs.io/zh-cn/latest/Instruction/推理和部署.html', - + # 评测 'eval': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/评测.html', '评测': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/评测.html', - + # 预训练与微调 'training': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/预训练与微调.html', '训练': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/预训练与微调.html', '微调': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/预训练与微调.html', - + # SWIFT安装 'install': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/SWIFT安装.html', '安装': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/SWIFT安装.html', - + # 快速开始 'quickstart': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/快速开始.html', '快速开始': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/快速开始.html', - + # 多模态 'multimodal': 'https://swift.readthedocs.io/zh-cn/latest/Multi-Modal/index.html', '多模态': 'https://swift.readthedocs.io/zh-cn/latest/Multi-Modal/index.html', - + # 强化学习 'rl': 'https://swift.readthedocs.io/zh-cn/latest/RLHF/index.html', '强化学习': 'https://swift.readthedocs.io/zh-cn/latest/RLHF/index.html', 'RLHF': 'https://swift.readthedocs.io/zh-cn/latest/RLHF/index.html', - + # 自定义 'custom': 'https://swift.readthedocs.io/zh-cn/latest/Customization/index.html', '自定义': 'https://swift.readthedocs.io/zh-cn/latest/Customization/index.html', @@ -62,65 +62,66 @@ # Frequently Asked Questions 'faq': 'https://swift.readthedocs.io/en/latest/Instruction/Frequently-asked-questions.html', 'frequently-asked-questions': 'https://swift.readthedocs.io/en/latest/Instruction/Frequently-asked-questions.html', - + # Supported Models and Datasets 'models': 'https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html', 'supported-models': 'https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html', 'datasets': 'https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html', - + # Command Line Parameters 'params': 'https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html', 'command-line-parameters': 'https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html', 'parameters': 'https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html', - + # Custom Dataset 'custom-dataset': 'https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html', 'custom-datasets': 'https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html', - + # Inference and Deployment 'deploy': 'https://swift.readthedocs.io/en/latest/Instruction/Inference-and-deployment.html', 'inference': 'https://swift.readthedocs.io/en/latest/Instruction/Inference-and-deployment.html', 'deployment': 'https://swift.readthedocs.io/en/latest/Instruction/Inference-and-deployment.html', - + # Evaluation 'eval': 'https://swift.readthedocs.io/en/latest/Instruction/Evaluation.html', 'evaluation': 'https://swift.readthedocs.io/en/latest/Instruction/Evaluation.html', - + # Pre-training and Fine-tuning 'training': 'https://swift.readthedocs.io/en/latest/Instruction/Pre-training-and-fine-tuning.html', 'pre-training': 'https://swift.readthedocs.io/en/latest/Instruction/Pre-training-and-fine-tuning.html', 'fine-tuning': 'https://swift.readthedocs.io/en/latest/Instruction/Pre-training-and-fine-tuning.html', - + # SWIFT Installation 'install': 'https://swift.readthedocs.io/en/latest/GetStarted/SWIFT-installation.html', 'installation': 'https://swift.readthedocs.io/en/latest/GetStarted/SWIFT-installation.html', - + # Quick 
Start 'quickstart': 'https://swift.readthedocs.io/en/latest/GetStarted/Quick-start.html', 'quick-start': 'https://swift.readthedocs.io/en/latest/GetStarted/Quick-start.html', - + # Multi-Modal 'multimodal': 'https://swift.readthedocs.io/en/latest/Multi-Modal/index.html', 'multi-modal': 'https://swift.readthedocs.io/en/latest/Multi-Modal/index.html', - + # Reinforcement Learning 'rl': 'https://swift.readthedocs.io/en/latest/RLHF/index.html', 'rlhf': 'https://swift.readthedocs.io/en/latest/RLHF/index.html', 'reinforcement-learning': 'https://swift.readthedocs.io/en/latest/RLHF/index.html', - + # Customization 'custom': 'https://swift.readthedocs.io/en/latest/Customization/index.html', 'customization': 'https://swift.readthedocs.io/en/latest/Customization/index.html', } + def get_url_alias(alias_key, language='zh'): """ 获取URL别名对应的完整URL - + Args: alias_key (str): 别名键 language (str): 语言,'zh' 或 'en' - + Returns: str: 完整的URL,如果找不到别名则返回None """ @@ -130,13 +131,14 @@ def get_url_alias(alias_key, language='zh'): return EN_URL_ALIASES.get(alias_key) return None + def get_all_aliases(language='zh'): """ 获取所有URL别名 - + Args: language (str): 语言,'zh' 或 'en' - + Returns: dict: 所有别名的字典 """ diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 9e000dbfcd..1dd6ebc4c8 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -21,7 +21,7 @@ class MegatronSft(SwiftSft): args: args_class def prepare_trainer(self): - return MegatronTrainer(self.args) + return MegatronTrainer(self.args, self.template) def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = None) -> None: self.train_msg = {} diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 16e3ba4410..a983d20598 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -37,8 +37,9 @@ class BaseMegatronTrainer(ABC): - def __init__(self, args): + def __init__(self, args, template): self.args = args + self.template = template self.stimer = StragglerDetector() logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 7cdc740375..bdb891e0f8 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -20,6 +20,7 @@ from .rlhf_base import MegatronRLHFTrainer from .trainer import MegatronTrainer from .utils import get_batch + try: from mbridge import AutoBridge except ImportError: @@ -30,61 +31,22 @@ class MegatronGRPOTrainer(MegatronRLHFTrainer): - def __init__(self, args: MegatronRLHFArguments): + def __init__(self, args: MegatronRLHFArguments, template): MegatronRLHFTrainer().__init__(self, args) - # TODO: init vllm self.args = args - self.processing_class = self.processor + self.hf_model_dir = args.model_info.model_dir + self.process_index = torch.distributed.get_rank() + self.processing_class = self.template.processor # set up reward funcs - reward_funcs = args.reward_funcs - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - if reward_funcs: - for i, reward_func in enumerate(reward_funcs): - if reward_func in orms: - reward_func_class = orms[reward_func] - reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) - reward_func_kwargs = { - key: getattr(args, key) - for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) - } - if 'tokenizer' in reward_func_args: - 
reward_func_kwargs['tokenizer'] = self.processing_class - reward_funcs[i] = reward_func_class(**reward_func_kwargs) - elif not callable(reward_func): - raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') - self.reward_funcs = reward_funcs - self.reward_func_names = [] - for reward_func in reward_funcs: - if inspect.isfunction(reward_func): - reward_func_name = reward_func.__name__ - else: - reward_func_name = reward_func.__class__.__name__ - self.reward_func_names.append(reward_func_name) - # TODO: reward model + self.prepare_rewards() # TODO: multi turn scheduler(colocate multi turn) self._prepare_rollout_engine(self.unwrapped_model) - self.num_generations = args.num_generations # G in the GRPO paper - self.temperature = args.temperature - self.loss_type = args.loss_type - self.max_completion_length = args.max_completion_length - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon - self.top_entropy_quantile = args.top_entropy_quantile - self.importance_sampling_level = args.importance_sampling_level - self.enable_offload = False - self.use_gym_env = False - # batch size - self.global_batch_size = args.global_batch_size - self.mini_batch_size = args.mini_batch_size - self.micro_batch_size = args.micro_batch_size - self.local_path = 'Qwen/Qwen2.5-7B-Instruct' # debug + self._init_grpo_params() if args.use_mbridge: # debug: use mbridge to convert mcore to hf from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(self.hf_model_dir, trust_remote_code=True) bridge = AutoBridge.from_pretrained(hf_config) - tf_config = bridge.config self.bridge = bridge def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): @@ -144,19 +106,21 @@ def _replace_data_iterator(self, data_iterator): res.append(self.ref_forward(ref_model, data_iterator)) return iter(res) - def forward_step(self, data_iterator, model): - - with torch.no_grad(): - data = next(data_iterator) - - ref_logps = data.pop('logps') - with self.stimer: - output_tensor = model(**data) - return output_tensor, partial( - self.loss_func, - ref_logps=ref_logps, - labels=data.get('labels'), - packed_seq_params=data.get('packed_seq_params')) + def _init_grpo_params(self): + args = self.args + self.num_generations = args.num_generations # G in the GRPO paper + self.temperature = args.temperature + self.loss_type = args.loss_type + self.max_completion_length = args.max_completion_length + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self.top_entropy_quantile = args.top_entropy_quantile + self.importance_sampling_level = args.importance_sampling_level + self.enable_offload = False + # batch size + self.global_batch_size = args.global_batch_size + self.mini_batch_size = args.mini_batch_size + self.micro_batch_size = args.micro_batch_size def _prepare_rollout_engine(self, model): args = self.args @@ -211,36 +175,91 @@ def _prepare_rollout_engine(self, model): context = self.offload_context if self.enable_offload else nullcontext with context(): - self.engine = self.prepare_vllm(model) + self.engine = self.prepare_vllm() if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) - - def prepare_vllm(self, model): + def prepare_vllm(self): from swift.tuners import Swift from swift.llm.infer.infer_engine import 
GRPOVllmEngine + args = self.args max_num_seqs = ( self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) - data_parallel_rank = mpu.get_data_parallel_rank() - - with Swift.grpo_context(model, self.template.processor): - engine = GRPOVllmEngine( - model.model_dir, - model.model_info.torch_dtype, - model_type=model.model_meta.model_type, - use_async_engine=False, # TODO: async engine for colocate - tensor_parallel_size=self.vllm_tensor_parallel_size, - gpu_memory_utilization=self.vllm_gpu_memory_utilization, - enable_prefix_caching=self.args.vllm_enable_prefix_caching, - max_num_seqs=max_num_seqs, - enforce_eager=self.args.vllm_enforce_eager, - limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, - enable_sleep_mode=self.args.sleep_level > 0, - max_model_len=self.args.vllm_max_model_len, - seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - disable_cascade_attn=self.args.vllm_disable_cascade_attn, - load_format='dummy', - template=copy(self.template), - distributed_executor_backend='external_launcher', - ) + + engine = GRPOVllmEngine( + self.hf_model_dir, + args.torch_dtype, + model_type=args.model_type, + use_async_engine=False, # TODO: async engine for colocate + tensor_parallel_size=self.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + enable_sleep_mode=self.args.sleep_level > 0, + max_model_len=self.args.vllm_max_model_len, + seed=self.process_index // self.vllm_tensor_parallel_size, + disable_cascade_attn=self.args.vllm_disable_cascade_attn, + load_format='dummy', + template=self.template, + distributed_executor_backend='external_launcher', + ) return engine + + def prepare_rewards(self): + # TODO: reward model + args = self.args + reward_funcs = args.reward_funcs + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + + def _move_model_to_vllm(self): + # TODO: LoRA, server + per_tensor_params = self.bridge.export_weights(self.unwrapped_model) + self.engine.inner_model.load_weights(per_tensor_params) + + def forward_step(self, data_iterator, model): + # train_batch_size + + # step1: rollout + + # step2: compute old logps + + # step3: compute ref logps + + # step4: compute rewards/advantages + + # return: output_tensor, loss_func + + data = next(data_iterator) + ref_logps = data.pop('logps') + with self.stimer: + output_tensor = model(**data) + 
return output_tensor, partial( + self.loss_func, + ref_logps=ref_logps, + labels=data.get('labels'), + packed_seq_params=data.get('packed_seq_params')) diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index 26064df297..daa5d752e9 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -107,3 +107,8 @@ def null_ref_context(self): yield ref_model if args.ref_adapter_load: self.peft_model.set_adapter('default') + + @contextmanager + def offload_context(self): + # TODO: offload + yield From 7c56f9f84dc92e5cf3bbe31576c4575959d14528 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 17 Sep 2025 21:22:11 +0800 Subject: [PATCH 07/83] override train_step wip --- swift/megatron/trainers/rlhf_base.py | 6 +- swift/megatron/trainers/utils.py | 146 ++++++++++++++++++++++++++- 2 files changed, 150 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index daa5d752e9..0930dce228 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -13,7 +13,7 @@ from swift.utils import get_current_device, get_logger from .trainer import MegatronTrainer -from .utils import get_batch +from .utils import get_batch, train_step logger = get_logger() @@ -112,3 +112,7 @@ def null_ref_context(self): def offload_context(self): # TODO: offload yield + + def _patch_megatron(self): + super()._patch_megatron() + self._origin_train_step = train_step diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index b5a0c7dde6..0f8014cbfb 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -3,9 +3,16 @@ import torch from megatron.core import mpu +from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank -from megatron.training import get_args +from megatron.training import get_args, get_timers +from megatron.training.training import (cuda_graph_capture, cuda_graph_set_manual_hooks, + get_tensor_shapes_adjust_fn_for_distillation, has_nvidia_modelopt) +from megatron.training.utils import (logical_and_across_model_parallel_group, + reduce_max_stat_across_model_parallel_group, unwrap_model) from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device @@ -151,3 +158,140 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch + + +def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): + # borrowed from Megatron-LM 0.13.rc2 + """Single training step.""" + args = get_args() + timers = get_timers() + + # CUDA Graph capturing only executes once, when it's the first training iteration. + if args.curr_iteration == args.iteration and args.external_cuda_graph: + cuda_graph_capture(model, config, args) + + # Set grad to zero. + for model_chunk in model: + model_chunk.zero_grad_buffer() + optimizer.zero_grad() + + # Collect garbage and empty unused memory. 
+ gc.collect() + torch.cuda.empty_cache() + + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + # Set grad to zero. + for model_chunk in model: + model_chunk.zero_grad_buffer() + optimizer.zero_grad() + + if has_nvidia_modelopt: + # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors + adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(model, args.seq_length, + args.micro_batch_size, + args.decoder_seq_length) + else: + adjust_tensor_shapes_fn = None + + # Forward pass. + forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=False, + adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, + ) + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() + if should_exit: + return {}, True, should_checkpoint, should_exit, exit_code, None, None + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Vision gradients. + if args.vision_pretraining and args.vision_pretraining_type == 'dino': + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) + + # Update parameters. + + timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + timers('optimizer').stop() + + # when freezing sub-models we may have a mixture of successful and unsucessful ranks, + # so we must gather across mp ranks + update_successful = logical_and_across_model_parallel_group(update_successful) + # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, + # so we must gather across mp ranks + grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) + if args.log_num_zeros_in_grad: + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad) + + # Vision momentum. + if args.vision_pretraining and args.vision_pretraining_type == 'dino': + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.update_momentum(args.curr_iteration) + + # Update learning rate. + if update_successful: + increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + # Set the manual hooks when CUDA Graphs are enabled. + if args.curr_iteration == args.iteration and args.external_cuda_graph: + if args.use_distributed_optimizer and args.overlap_param_gather: + cuda_graph_set_manual_hooks(model) + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. 
+ loss_reduced = {} + + for key in losses_reduced[0].keys(): + val = [x[key].view(-1) for x in losses_reduced] + if val[0].numel() == 2: + if args.sft: + # in mcore the normalization happens on micro batch instead of global + val = torch.vstack(val) + val = val[:, 0] / val[:, 1] + val = val.mean() + torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) + val /= torch.distributed.get_world_size( + group=mpu.get_data_parallel_group(with_context_parallel=True)) + loss_reduced[key] = val + else: + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. + val = torch.vstack(val).sum(dim=0) + torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) + loss_reduced[key] = val[0] / val[1] + elif val[0].numel() == 1: + # legacy behavior, we average over the number of microbatches + val = torch.cat(val).mean() + loss_reduced[key] = val + else: + raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}') + return ( + loss_reduced, + skipped_iter, + should_checkpoint, + should_exit, + exit_code, + grad_norm, + num_zeros_in_grad, + ) + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad From 686fc74262ac47e75ff8e9b09872481dab09dfe2 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 18 Sep 2025 11:33:54 +0800 Subject: [PATCH 08/83] remove override train_step to grpo --- swift/megatron/trainers/grpo_trainer.py | 157 +++++++++++++++++++- swift/megatron/trainers/rlhf_base.py | 6 +- swift/megatron/trainers/utils.py | 147 +----------------- swift/trainers/rlhf_trainer/grpo_trainer.py | 125 ++++++++-------- 4 files changed, 213 insertions(+), 222 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index bdb891e0f8..d570b4af0a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1,16 +1,24 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
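# A minimal, standalone Python sketch (names and numbers are illustrative assumptions,
# not code from this repository) of the batching arithmetic this file's train_step relies
# on once _replace_data_iterator has prepared the data: each data-parallel rank consumes
# global_batch_size // dp_world_size samples per optimizer step, micro_batch_size at a
# time, so num_iters_per_step = global_batch_size // (micro_batch_size * dp_world_size).
def split_into_microbatches(global_batch, micro_batch_size, dp_world_size, dp_rank):
    per_rank = len(global_batch) // dp_world_size
    local = global_batch[dp_rank * per_rank:(dp_rank + 1) * per_rank]
    return [local[i:i + micro_batch_size] for i in range(0, per_rank, micro_batch_size)]

# Example: global_batch_size=16, micro_batch_size=2, dp_world_size=4
# -> every rank iterates over 2 micro-batches of 2 samples per train_step call.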
+import gc import inspect from collections import namedtuple from contextlib import contextmanager, nullcontext from functools import partial +from typing import Any, Dict import torch from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from megatron.core import mpu from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank -from megatron.training import get_args, get_model, training +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.training import get_args, get_model, get_timers, training from megatron.training.checkpointing import load_checkpoint -from megatron.training.utils import unwrap_model +from megatron.training.training import (cuda_graph_capture, cuda_graph_set_manual_hooks, + get_tensor_shapes_adjust_fn_for_distillation, has_nvidia_modelopt) +from megatron.training.utils import (logical_and_across_model_parallel_group, + reduce_max_stat_across_model_parallel_group, unwrap_model) from torch.distributed.nn import all_reduce from swift.plugin import orms @@ -49,12 +57,6 @@ def __init__(self, args: MegatronRLHFArguments, template): bridge = AutoBridge.from_pretrained(hf_config) self.bridge = bridge - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - # prepare global batch data here - new_data_iterator = self._replace_data_iterator(data_iterator) - return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, - config) - def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, packed_seq_params): args = get_args() @@ -263,3 +265,142 @@ def forward_step(self, data_iterator, model): ref_logps=ref_logps, labels=data.get('labels'), packed_seq_params=data.get('packed_seq_params')) + + def _patch_megatron(self): + super()._patch_megatron() + self._origin_train_step = self.train_step + + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): + # borrowed from Megatron-LM 0.13 + """Single training step.""" + args = get_args() + timers = get_timers() + + # CUDA Graph capturing only executes once, when it's the first training iteration. + if args.curr_iteration == args.iteration and args.external_cuda_graph: + cuda_graph_capture(model, config, args) + + # Set grad to zero. + for model_chunk in model: + model_chunk.zero_grad_buffer() + optimizer.zero_grad() + + # Collect garbage and empty unused memory. + gc.collect() + torch.cuda.empty_cache() + + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + # Set grad to zero. + for model_chunk in model: + model_chunk.zero_grad_buffer() + optimizer.zero_grad() + + if has_nvidia_modelopt: + # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors + adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( + model, args.seq_length, args.micro_batch_size, args.decoder_seq_length) + else: + adjust_tensor_shapes_fn = None + + # Forward pass. 
+ forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=False, + adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, + ) + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() + if should_exit: + return {}, True, should_checkpoint, should_exit, exit_code, None, None + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Vision gradients. + if args.vision_pretraining and args.vision_pretraining_type == 'dino': + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) + + # Update parameters. + + timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + timers('optimizer').stop() + + # when freezing sub-models we may have a mixture of successful and unsucessful ranks, + # so we must gather across mp ranks + update_successful = logical_and_across_model_parallel_group(update_successful) + # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, + # so we must gather across mp ranks + grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) + if args.log_num_zeros_in_grad: + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad) + + # Vision momentum. + if args.vision_pretraining and args.vision_pretraining_type == 'dino': + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.update_momentum(args.curr_iteration) + + # Update learning rate. + if update_successful: + increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + # Set the manual hooks when CUDA Graphs are enabled. + if args.curr_iteration == args.iteration and args.external_cuda_graph: + if args.use_distributed_optimizer and args.overlap_param_gather: + cuda_graph_set_manual_hooks(model) + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + + for key in losses_reduced[0].keys(): + val = [x[key].view(-1) for x in losses_reduced] + if val[0].numel() == 2: + if args.sft: + # in mcore the normalization happens on micro batch instead of global + val = torch.vstack(val) + val = val[:, 0] / val[:, 1] + val = val.mean() + torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) + val /= torch.distributed.get_world_size( + group=mpu.get_data_parallel_group(with_context_parallel=True)) + loss_reduced[key] = val + else: + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. 
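# Concretely, each micro-batch here reports a (loss_sum, num_tokens) pair. Suppose one
# data-parallel rank sees (6.0, 3) and (4.0, 1) (illustrative numbers, not from the code):
#   * this branch, token-weighted over the global batch: (6.0 + 4.0) / (3 + 1) = 2.5
#   * the `args.sft` branch above, mean of per-micro-batch means: (6.0/3 + 4.0/1) / 2 = 3.0
# In this branch the all_reduce over the data-parallel group (with context parallel) then
# adds the same (loss_sum, num_tokens) pairs across ranks before the final division.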
+ val = torch.vstack(val).sum(dim=0) + torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) + loss_reduced[key] = val[0] / val[1] + elif val[0].numel() == 1: + # legacy behavior, we average over the number of microbatches + val = torch.cat(val).mean() + loss_reduced[key] = val + else: + raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}') + return ( + loss_reduced, + skipped_iter, + should_checkpoint, + should_exit, + exit_code, + grad_norm, + num_zeros_in_grad, + ) + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index 0930dce228..daa5d752e9 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -13,7 +13,7 @@ from swift.utils import get_current_device, get_logger from .trainer import MegatronTrainer -from .utils import get_batch, train_step +from .utils import get_batch logger = get_logger() @@ -112,7 +112,3 @@ def null_ref_context(self): def offload_context(self): # TODO: offload yield - - def _patch_megatron(self): - super()._patch_megatron() - self._origin_train_step = train_step diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 0f8014cbfb..4f40f6023e 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -3,20 +3,12 @@ import torch from megatron.core import mpu -from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank -from megatron.training import get_args, get_timers -from megatron.training.training import (cuda_graph_capture, cuda_graph_set_manual_hooks, - get_tensor_shapes_adjust_fn_for_distillation, has_nvidia_modelopt) -from megatron.training.utils import (logical_and_across_model_parallel_group, - reduce_max_stat_across_model_parallel_group, unwrap_model) +from megatron.training import get_args from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device -from swift.utils import get_current_device def get_swift_datasets_provider(train_dataset, val_dataset): @@ -158,140 +150,3 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch - - -def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - # borrowed from Megatron-LM 0.13.rc2 - """Single training step.""" - args = get_args() - timers = get_timers() - - # CUDA Graph capturing only executes once, when it's the first training iteration. - if args.curr_iteration == args.iteration and args.external_cuda_graph: - cuda_graph_capture(model, config, args) - - # Set grad to zero. - for model_chunk in model: - model_chunk.zero_grad_buffer() - optimizer.zero_grad() - - # Collect garbage and empty unused memory. - gc.collect() - torch.cuda.empty_cache() - - rerun_state_machine = get_rerun_state_machine() - while rerun_state_machine.should_run_forward_backward(data_iterator): - # Set grad to zero. 
- for model_chunk in model: - model_chunk.zero_grad_buffer() - optimizer.zero_grad() - - if has_nvidia_modelopt: - # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors - adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(model, args.seq_length, - args.micro_batch_size, - args.decoder_seq_length) - else: - adjust_tensor_shapes_fn = None - - # Forward pass. - forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=get_num_microbatches(), - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=False, - adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, - ) - should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() - if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None - - # Empty unused memory. - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - # Vision gradients. - if args.vision_pretraining and args.vision_pretraining_type == 'dino': - unwrapped_model = unwrap_model(model[0]) - unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) - - # Update parameters. - - timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) - update_successful, grad_norm, num_zeros_in_grad = optimizer.step() - timers('optimizer').stop() - - # when freezing sub-models we may have a mixture of successful and unsucessful ranks, - # so we must gather across mp ranks - update_successful = logical_and_across_model_parallel_group(update_successful) - # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, - # so we must gather across mp ranks - grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) - if args.log_num_zeros_in_grad: - num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad) - - # Vision momentum. - if args.vision_pretraining and args.vision_pretraining_type == 'dino': - unwrapped_model = unwrap_model(model[0]) - unwrapped_model.update_momentum(args.curr_iteration) - - # Update learning rate. - if update_successful: - increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size - opt_param_scheduler.step(increment=increment) - skipped_iter = 0 - else: - skipped_iter = 1 - - # Empty unused memory. - if args.empty_unused_memory_level >= 2: - torch.cuda.empty_cache() - - # Set the manual hooks when CUDA Graphs are enabled. - if args.curr_iteration == args.iteration and args.external_cuda_graph: - if args.use_distributed_optimizer and args.overlap_param_gather: - cuda_graph_set_manual_hooks(model) - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Average loss across microbatches. - loss_reduced = {} - - for key in losses_reduced[0].keys(): - val = [x[key].view(-1) for x in losses_reduced] - if val[0].numel() == 2: - if args.sft: - # in mcore the normalization happens on micro batch instead of global - val = torch.vstack(val) - val = val[:, 0] / val[:, 1] - val = val.mean() - torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) - val /= torch.distributed.get_world_size( - group=mpu.get_data_parallel_group(with_context_parallel=True)) - loss_reduced[key] = val - else: - # there is one dict per microbatch. in new reporting, we average - # over the total number of tokens across the global batch. 
- val = torch.vstack(val).sum(dim=0) - torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) - loss_reduced[key] = val[0] / val[1] - elif val[0].numel() == 1: - # legacy behavior, we average over the number of microbatches - val = torch.cat(val).mean() - loss_reduced[key] = val - else: - raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}') - return ( - loss_reduced, - skipped_iter, - should_checkpoint, - should_exit, - exit_code, - grad_norm, - num_zeros_in_grad, - ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index e00b343271..3335e732c8 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -172,16 +172,24 @@ def __init__(self, self.num_generations = args.num_generations self.temperature = args.temperature + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.loss_type = args.loss_type self.max_completion_length = args.max_completion_length self.completion_length_limit_scope = args.completion_length_limit_scope model.warnings_issued['estimate_tokens'] = True + kwargs['data_collator'] = identity_data_collator # No data collation is needed in GRPO self.shuffle_dataset = args.dataset_shuffle + + self.use_vllm = args.use_vllm + self.async_generate = args.async_generate + vllm_client = kwargs.pop('vllm_client') # for external vllm + self.model_kwarg_keys = ( inspect.signature(model.forward).parameters.keys() if not hasattr(model, 'get_base_model') else inspect.signature(model.get_base_model().forward).parameters.keys()) - self.vllm_client = kwargs.pop('vllm_client') chord_sft_dataset = kwargs.pop('chord_sft_dataset', None) super().__init__(model, ref_model, *_args, **kwargs) self.chord_sft_iterator = None @@ -238,7 +246,59 @@ def __init__(self, set_seed(args.seed, device_specific=True) if is_peft_model(self.model): self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() - self._prepare_rollout_engine() + self.use_fast_infer = self.use_vllm # whether to use the PT backend + self.vllm_use_async_engine = False + self.enable_offload = False + self.use_gym_env = False + self.enable_server_multi_turn = False + # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs + self.dynamic_num_samples = False + if self.use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.vllm_mode == 'server': + self.vllm_client: VLLMClient = vllm_client + if self.accelerator.is_main_process: + self.vllm_client.get_engine_type() + vllm_use_async_engine = [self.vllm_client.use_async_engine] + use_gym_env = [self.vllm_client.use_gym_env] + enable_multi_turn = [self.vllm_client.enable_multi_turn] + else: + vllm_use_async_engine = [False] + use_gym_env = [False] + enable_multi_turn = [self.enable_server_multi_turn] + self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] + self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] + self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + if self.use_gym_env: + self.reward_func_names = ['gym_reward'] + + elif self.vllm_mode == 'colocate': + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.accelerator.num_processes}) evenly.') + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. + # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ + list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ]) + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + self.engine = self.prepare_vllm(model) + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + + else: + from swift.llm import PtEngine + self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit + if not self.reward_funcs and not self.use_gym_env: raise ValueError('You must specify reward_funcs or reward_model') @@ -2832,64 +2892,3 @@ def _get_last_indices(self, request_ids: List[str]) -> torch.Tensor: for i, rid in enumerate(request_ids): seen[rid] = i return torch.tensor(list(seen.values()), dtype=torch.long, device=self.accelerator.device) - - def _prepare_rollout_engine(self, model): - args = self.args - self.vllm_mode = args.vllm_mode - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode - self.use_vllm = args.use_vllm - self.async_generate = args.async_generate - vllm_client = getattr(args, 'vllm_client') or getattr(self, 'vllm_client') # for external vllm - self.use_fast_infer = self.use_vllm # whether to use the PT backend - self.vllm_use_async_engine = False - self.enable_offload = False - self.use_gym_env = False - self.enable_server_multi_turn = False - # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs - self.dynamic_num_samples = False - if self.use_vllm: - if not is_vllm_available(): - raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' - 'Please install vLLM with `pip install vllm -U` to use it.') - if self.vllm_mode == 'server': - self.vllm_client: VLLMClient = vllm_client - if self.accelerator.is_main_process: - self.vllm_client.get_engine_type() - vllm_use_async_engine = [self.vllm_client.use_async_engine] - use_gym_env = [self.vllm_client.use_gym_env] - enable_multi_turn = [self.vllm_client.enable_multi_turn] - else: - vllm_use_async_engine = [False] - use_gym_env = [False] - enable_multi_turn = [self.enable_server_multi_turn] - self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] - self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] - self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] - if self.use_gym_env: - self.reward_func_names = ['gym_reward'] - - elif self.vllm_mode == 'colocate': - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' - f'({self.accelerator.num_processes}) evenly.') - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. - # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ]) - self.enable_offload = self.args.offload_model or self.args.offload_optimizer - context = self.offload_context if self.enable_offload else nullcontext - - with context(): - self.engine = self.prepare_vllm(model) - if self.args.sleep_level > 0: - self.engine.engine.sleep(self.args.sleep_level) - - else: - from swift.llm import PtEngine - self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit From 4d9457bbff5d602861228bc50a3a2639c3862ad9 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 18 Sep 2025 20:59:46 +0800 Subject: [PATCH 09/83] sync weight wip --- swift/megatron/train/rlhf.py | 9 +- swift/megatron/trainers/dpo_trainer.py | 2 +- swift/megatron/trainers/grpo_trainer.py | 241 +++++++++++++++++------- 3 files changed, 179 insertions(+), 73 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index c8f3129b4b..e98dd3c2c8 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
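# The GRPO branch added below returns identity_data_collator so that raw per-sample dicts
# (messages, multimodal fields, ...) reach the trainer without padding or stacking; template
# encoding happens later, after rollout. The real helper lives in
# swift.trainers.rlhf_trainer.utils and is not shown in this patch; a minimal stand-in with
# the assumed behavior would be:
def identity_data_collator(features):
    # pass the list of samples through unchanged
    return features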
from typing import List, Optional, Union +from swift.trainers.rlhf_trainer.utils import identity_data_collator from swift.utils import get_logger from ..argument import MegatronRLHFArguments from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer @@ -21,12 +22,18 @@ def prepare_trainer(self): trainer_cls = MegatronGRPOTrainer else: raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.') - return trainer_cls(args) + return trainer_cls(args, self.template) def _prepare_template(self) -> None: super()._prepare_template() self.template.set_mode('rlhf') + def _get_data_collator(self): + if self.args.rlhf_type == 'grpo': + super()._get_data_collator() + return identity_data_collator + return super()._get_data_collator() + def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 87d6d26263..4e57642175 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -36,7 +36,7 @@ def __init__(self, args): class MegatronDPOTrainer(MegatronRLHFTrainer): - def __init__(self, args): + def __init__(self, args, template): super().__init__(args) self.dummy_dpo_trainer = DummyDPOTrainer(args) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index d570b4af0a..bfe329fd5b 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -3,8 +3,9 @@ import inspect from collections import namedtuple from contextlib import contextmanager, nullcontext +from copy import copy from functools import partial -from typing import Any, Dict +from typing import Any, Dict, List import torch from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed @@ -15,12 +16,14 @@ from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.training import get_args, get_model, get_timers, training from megatron.training.checkpointing import load_checkpoint -from megatron.training.training import (cuda_graph_capture, cuda_graph_set_manual_hooks, - get_tensor_shapes_adjust_fn_for_distillation, has_nvidia_modelopt) +from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks from megatron.training.utils import (logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group, unwrap_model) from torch.distributed.nn import all_reduce +from vllm.distributed import parallel_state as vllm_ps +from swift.llm import RequestConfig +from swift.llm.infer.protocol import RolloutOutput from swift.plugin import orms from swift.trainers.rlhf_trainer import GRPOTrainer, VLLMClient from swift.utils import get_current_device, get_logger, is_vllm_available @@ -34,32 +37,53 @@ except ImportError: pass +try: + from megatron.post_training.algos.distillation import ( + get_tensor_shapes_adjust_fn_for_distillation, ) + + has_nvidia_modelopt = True +except ImportError: + has_nvidia_modelopt = False + logger = get_logger() +""" +TODO: + 1. compute rewards + 2. compute advantages + 3. compute ref/old logps + 4. loss + + +FUTURE TODO: + 1. server mode + 2. offload model/optimizer + 3. DAPO (dynamic sampling) + 4. entropy mask + 5. reward model + 6. 
multi turn +""" class MegatronGRPOTrainer(MegatronRLHFTrainer): def __init__(self, args: MegatronRLHFArguments, template): - MegatronRLHFTrainer().__init__(self, args) + super().__init__(args, template) self.args = args self.hf_model_dir = args.model_info.model_dir - self.process_index = torch.distributed.get_rank() self.processing_class = self.template.processor # set up reward funcs self.prepare_rewards() # TODO: multi turn scheduler(colocate multi turn) - self._prepare_rollout_engine(self.unwrapped_model) self._init_grpo_params() - if args.use_mbridge: - # debug: use mbridge to convert mcore to hf - from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained(self.hf_model_dir, trust_remote_code=True) - bridge = AutoBridge.from_pretrained(hf_config) - self.bridge = bridge + self._prepare_rollout_engine() + + # debug: use mbridge to convert mcore to hf + self.bridge = None def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, packed_seq_params): - args = get_args() + # TODO:GRPO policy loss + args: MegatronRLHFArguments = get_args() num_samples = packed_seq_params.num_samples logps = self.get_logps(output_tensor, labels, packed_seq_params) @@ -100,7 +124,18 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab return loss, reporting_metric def _replace_data_iterator(self, data_iterator): - args = get_args() + # rollout the train_batch_size, split to mini-batches + data = next(data_iterator) + data = self._generate_and_score_completions(data) + # step1: rollout + + # step2: compute old logps + + # step3: compute ref logps + + # step4: compute rewards/advantages + + args: MegatronRLHFArguments = get_args() num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) res = [] with torch.no_grad(), self.null_ref_context() as ref_model: @@ -110,6 +145,11 @@ def _replace_data_iterator(self, data_iterator): def _init_grpo_params(self): args = self.args + # distributed params + self.world_size = torch.distributed.get_world_size() + self.process_index = torch.distributed.get_rank() + + # algorithm params self.num_generations = args.num_generations # G in the GRPO paper self.temperature = args.temperature self.loss_type = args.loss_type @@ -123,15 +163,25 @@ def _init_grpo_params(self): self.global_batch_size = args.global_batch_size self.mini_batch_size = args.mini_batch_size self.micro_batch_size = args.micro_batch_size - - def _prepare_rollout_engine(self, model): + self.per_device_rollout_batch_size = self.global_batch_size // self.world_size + # sampling params + self.request_config = RequestConfig( + n=1, + max_tokens=args.max_completion_length, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + stop=args.stop_words, + return_details=True) + + def _prepare_rollout_engine(self): args = self.args self.vllm_mode = args.vllm_mode self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.use_vllm = args.use_vllm self.async_generate = args.async_generate - vllm_client = getattr(args, 'vllm_client') or getattr(self, 'vllm_client') # for external vllm self.use_fast_infer = self.use_vllm # whether to use the PT backend self.vllm_use_async_engine = False self.enable_offload = False @@ -143,56 +193,30 @@ def _prepare_rollout_engine(self, model): if not 
is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' 'Please install vLLM with `pip install vllm -U` to use it.') - if self.vllm_mode == 'server': - self.vllm_client: VLLMClient = vllm_client - if self.accelerator.is_main_process: - self.vllm_client.get_engine_type() - vllm_use_async_engine = [self.vllm_client.use_async_engine] - use_gym_env = [self.vllm_client.use_gym_env] - enable_multi_turn = [self.vllm_client.enable_multi_turn] - else: - vllm_use_async_engine = [False] - use_gym_env = [False] - enable_multi_turn = [self.enable_server_multi_turn] - self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] - self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] - self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] - if self.use_gym_env: - self.reward_func_names = ['gym_reward'] - - elif self.vllm_mode == 'colocate': - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' - f'({self.accelerator.num_processes}) evenly.') - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. - # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration([ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ]) - self.enable_offload = self.args.offload_model or self.args.offload_optimizer - context = self.offload_context if self.enable_offload else nullcontext - - with context(): - self.engine = self.prepare_vllm() - if self.args.sleep_level > 0: - self.engine.engine.sleep(self.args.sleep_level) + assert self.vllm_mode == 'colocate' # TODO: server mode + + if not self.world_size % self.vllm_tensor_parallel_size == 0: + raise ValueError(f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.world_size}) evenly.') + + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + self.engine = self.prepare_vllm() + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) def prepare_vllm(self): - from swift.tuners import Swift from swift.llm.infer.infer_engine import GRPOVllmEngine args = self.args - max_num_seqs = ( - self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) + max_num_seqs = self.per_device_rollout_batch_size * self.vllm_tensor_parallel_size engine = GRPOVllmEngine( self.hf_model_dir, args.torch_dtype, model_type=args.model_type, - use_async_engine=False, # TODO: async engine for colocate + use_async_engine=False, tensor_parallel_size=self.vllm_tensor_parallel_size, gpu_memory_utilization=self.vllm_gpu_memory_utilization, enable_prefix_caching=self.args.vllm_enable_prefix_caching, @@ -207,6 +231,8 @@ def prepare_vllm(self): template=self.template, distributed_executor_backend='external_launcher', ) + if self.vllm_tensor_parallel_size > 1: + self.vllm_tp_group = vllm_ps.get_tp_group().device_group return engine def prepare_rewards(self): @@ -240,20 +266,13 @@ def prepare_rewards(self): def 
_move_model_to_vllm(self): # TODO: LoRA, server - per_tensor_params = self.bridge.export_weights(self.unwrapped_model) - self.engine.inner_model.load_weights(per_tensor_params) + if self.bridge is None: + self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) + per_tensor_params = self.bridge.export_weights([self.unwrapped_model]) + self.engine.inner_model.load_weights(per_tensor_params) # TODO: check tensor_model_parallel def forward_step(self, data_iterator, model): # train_batch_size - - # step1: rollout - - # step2: compute old logps - - # step3: compute ref logps - - # step4: compute rewards/advantages - # return: output_tensor, loss_func data = next(data_iterator) @@ -272,9 +291,16 @@ def _patch_megatron(self): def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): # borrowed from Megatron-LM 0.13 - """Single training step.""" - args = get_args() + # get train_batch_size Rollout / ref/old logps / reward / advantage + # split to mini_batches (iter mini_batch) + data_iterator = self._replace_data_iterator(data_iterator) + + args: MegatronRLHFArguments = get_args() timers = get_timers() + batch = next(data_iterator) + batch = self._generate_and_score_completions(batch) + + # split to mini-batches # CUDA Graph capturing only executes once, when it's the first training iteration. if args.curr_iteration == args.iteration and args.external_cuda_graph: @@ -404,3 +430,76 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par num_zeros_in_grad, ) return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad + + def _generate_and_score_completions(self, batch): + batch = self._generate_completions(batch) + # total_rewards_per_func = self._score_completions(batch) + # TODO: dynamic sampling + # total_advantages = self._compute_advantages(batch, total_rewards_per_func) + # batch = self._prepare_batch_inputs(batch) + return batch + + def _generate_completions(self, batch): + # TODO: server mode + assert self.vllm_mode == 'colocate' + # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) + if self.engine.inner_model_executor.is_sleeping: + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) + + # Step 2: Load model weights + self._move_model_to_vllm() + + if (self.engine.inner_model_executor.is_sleeping + and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): + self.engine.engine.wake_up(tags=['kv_cache']) + + batch = self.preprocess_rollout_data(batch) + output: List[RolloutOutput] = self._rollout(batch) + batch = self.postprocess_rollout_data(output) + return batch + + def preprocess_rollout_data(self, batch): + if self.vllm_tensor_parallel_size == 1: + return batch + + gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) + flattened_batch = [p for sublist in gathered_batch for p in sublist] + return flattened_batch + + def _rollout(self, batch) -> List[RolloutOutput]: + request_config = self._get_request_config() + # TODO: server mode + rollout_outputs = self._colocate_rollout(batch, request_config) + return rollout_outputs + + def postprocess_rollout_data(self, batch): + if self.vllm_tensor_parallel_size == 1: + return batch + local_rank_in_group = 
torch.distributed.get_rank(group=self.vllm_tp_group) + orig_size = len(batch) // self.vllm_tensor_parallel_size + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + return batch[tp_slice] + + def _get_request_config(self) -> RequestConfig: + request_config = copy(self.request_config) + if self.args.vllm_mode == 'colocate' and self.vllm_tensor_parallel_size > 1: + # Set request_config.seed + # 1. Ensure that the seed for vLLM Engines within each TP (Tensor Parallelism) group is the same; + # otherwise, the program may hang. + # 2. Ensure that the seed for vLLM Engines across different TP groups is different; + # otherwise, identical completions will be generated. + batch_size = self.per_device_rollout_batch_size + batch_size *= self.vllm_tensor_parallel_size + # Since the TP (Tensor Parallelism) group gathers the inputs, + # multiply the batch size by the TP parallel size. + request_config.seed = batch_size * (self.process_index // self.vllm_tensor_parallel_size) + + return request_config + + def _colocate_rollout(self, batch, request_config: RequestConfig): + outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) + return outputs From f52d5e108cf8469e6615fb6a5379938f5b9bec44 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 19 Sep 2025 18:05:59 +0800 Subject: [PATCH 10/83] rollout wip --- swift/megatron/argument/megatron_args.py | 3 +- swift/megatron/model/gpt_model.py | 74 ++++++++ swift/megatron/trainers/base.py | 3 +- swift/megatron/trainers/grpo_trainer.py | 226 +++++++++++++++++++---- swift/megatron/trainers/utils.py | 19 +- 5 files changed, 291 insertions(+), 34 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 658494a314..c2f399016b 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -17,6 +17,7 @@ @dataclass class RLHFMegatronArgumentsMixin: + perform_initialization: bool = True rlhf_type: Literal['dpo', 'grpo'] = 'dpo' ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -354,7 +355,7 @@ class MegatronArguments(ExtraMegatronArguments): no_load_rng: bool = False finetune: bool = False ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' - no_initialization: bool = True + no_initialization: bool = False auto_detect_ckpt_format: bool = True exit_on_missing_checkpoint: bool = True diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index f03a7c855e..829888006e 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -91,6 +91,80 @@ def __init__( logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') + # Set tensor_model_parallel attributes for all parameters + # This is needed for mbridge to correctly identify TP parameters + # self._set_tensor_model_parallel_attributes() + + def _set_tensor_model_parallel_attributes(self): + """Set tensor_model_parallel attributes for all parameters. + + This method ensures that all parameters have the correct tensor_model_parallel + attributes set, which is required for mbridge to correctly identify TP parameters + during weight export. 
+ """ + from megatron.core.tensor_parallel.layers import set_tensor_model_parallel_attributes + + # Get tensor parallel size + from megatron.core import parallel_state + tp_size = parallel_state.get_tensor_model_parallel_world_size() + + if tp_size <= 1: + return # No tensor parallelism, no need to set attributes + + # Set attributes for all parameters + for name, param in self.named_parameters(): + if not hasattr(param, 'tensor_model_parallel'): + # Determine if this parameter should be tensor parallel + is_tp_param = self._is_tensor_parallel_parameter(name, param) + if is_tp_param: + # Determine partition dimension based on parameter name + partition_dim = self._get_partition_dimension(name, param) + set_tensor_model_parallel_attributes(param, True, partition_dim, 1) + else: + # Set default attributes for non-TP parameters + setattr(param, 'tensor_model_parallel', False) + setattr(param, 'partition_dim', -1) + setattr(param, 'partition_stride', 1) + + def _is_tensor_parallel_parameter(self, name: str, param) -> bool: + """Determine if a parameter should be tensor parallel based on its name and shape.""" + # Parameters that are typically tensor parallel + tp_patterns = [ + 'weight', # Linear layer weights + 'qkv_proj.weight', # QKV projection weights + 'dense.weight', # Dense layer weights + 'fc1.weight', # MLP first layer weights + 'fc2.weight', # MLP second layer weights + 'gate_proj.weight', # Gate projection weights + 'up_proj.weight', # Up projection weights + 'down_proj.weight', # Down projection weights + ] + + # Check if parameter name matches any TP pattern + for pattern in tp_patterns: + if pattern in name: + return True + + # Special cases for bias parameters in TP layers + if 'bias' in name and any(pattern in name for pattern in tp_patterns): + return True + + return False + + def _get_partition_dimension(self, name: str, param) -> int: + """Get the partition dimension for a tensor parallel parameter.""" + # Column parallel layers (partition along output dimension) + if any(pattern in name + for pattern in ['qkv_proj.weight', 'gate_proj.weight', 'up_proj.weight', 'fc1.weight', 'dense.weight']): + return 0 # Partition along output dimension + + # Row parallel layers (partition along input dimension) + if any(pattern in name for pattern in ['down_proj.weight', 'fc2.weight']): + return 1 # Partition along input dimension + + # Default to partition along output dimension + return 0 + @contextmanager def _patch_apply_rotary_pos_emb(self): if self.attention_scaling == 1.: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9d217cf03f..57f9d96c83 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -26,6 +26,7 @@ from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory from packaging import version +from swift.llm import Template from swift.plugin import MeanMetric from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger @@ -37,7 +38,7 @@ class BaseMegatronTrainer(ABC): - def __init__(self, args, template): + def __init__(self, args, template: Template): self.args = args self.template = template self.stimer = StragglerDetector() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index bfe329fd5b..fd4acadbac 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -5,9 +5,10 @@ from contextlib import contextmanager, nullcontext from copy 
import copy from functools import partial -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import torch +import torch.nn as nn from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from megatron.core import mpu from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank @@ -22,15 +23,17 @@ from torch.distributed.nn import all_reduce from vllm.distributed import parallel_state as vllm_ps -from swift.llm import RequestConfig +from swift.llm import RequestConfig, RowPreprocessor, Template, to_device from swift.llm.infer.protocol import RolloutOutput from swift.plugin import orms from swift.trainers.rlhf_trainer import GRPOTrainer, VLLMClient -from swift.utils import get_current_device, get_logger, is_vllm_available +from swift.trainers.rlhf_trainer.grpo_trainer import DataType +from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids +from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer from .trainer import MegatronTrainer -from .utils import get_batch +from .utils import get_batch, profiling_context try: from mbridge import AutoBridge @@ -66,15 +69,14 @@ class MegatronGRPOTrainer(MegatronRLHFTrainer): - def __init__(self, args: MegatronRLHFArguments, template): + def __init__(self, args: MegatronRLHFArguments, template: Template): super().__init__(args, template) self.args = args self.hf_model_dir = args.model_info.model_dir self.processing_class = self.template.processor - # set up reward funcs - self.prepare_rewards() # TODO: multi turn scheduler(colocate multi turn) self._init_grpo_params() + self._prepare_rewards() self._prepare_rollout_engine() # debug: use mbridge to convert mcore to hf @@ -125,15 +127,9 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab def _replace_data_iterator(self, data_iterator): # rollout the train_batch_size, split to mini-batches - data = next(data_iterator) - data = self._generate_and_score_completions(data) - # step1: rollout - - # step2: compute old logps + batch = next(data_iterator) # [global_batch_size, ] - # step3: compute ref logps - - # step4: compute rewards/advantages + batch = self._generate_and_score_completions(batch) args: MegatronRLHFArguments = get_args() num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) @@ -148,7 +144,8 @@ def _init_grpo_params(self): # distributed params self.world_size = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() - + self.is_main_process = self.process_index == 0 + self.device = get_current_device() # algorithm params self.num_generations = args.num_generations # G in the GRPO paper self.temperature = args.temperature @@ -163,7 +160,7 @@ def _init_grpo_params(self): self.global_batch_size = args.global_batch_size self.mini_batch_size = args.mini_batch_size self.micro_batch_size = args.micro_batch_size - self.per_device_rollout_batch_size = self.global_batch_size // self.world_size + self.per_device_rollout_batch_size = self.global_batch_size // self.world_size * self.num_generations # sampling params self.request_config = RequestConfig( n=1, @@ -235,7 +232,7 @@ def prepare_vllm(self): self.vllm_tp_group = vllm_ps.get_tp_group().device_group return engine - def prepare_rewards(self): + def _prepare_rewards(self): # 
TODO: reward model args = self.args reward_funcs = args.reward_funcs @@ -264,6 +261,14 @@ def prepare_rewards(self): reward_func_name = reward_func.__class__.__name__ self.reward_func_names.append(reward_func_name) + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' + f'functions ({len(reward_funcs)})') + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32).to(self.device) + else: + self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(self.device) + def _move_model_to_vllm(self): # TODO: LoRA, server if self.bridge is None: @@ -432,11 +437,26 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad def _generate_and_score_completions(self, batch): - batch = self._generate_completions(batch) - # total_rewards_per_func = self._score_completions(batch) - # TODO: dynamic sampling - # total_advantages = self._compute_advantages(batch, total_rewards_per_func) - # batch = self._prepare_batch_inputs(batch) + + def get_local_rollout_data(batch): + # repeat num_generations times + global_rollout_data = [item for item in batch for _ in range(self.num_generations)] + # get local rollout data + data_slice = slice(self.process_index * self.per_device_rollout_batch_size, + (self.process_index + 1) * self.per_device_rollout_batch_size) + rollout_data = global_rollout_data[data_slice] + return rollout_data + + rollout_data = get_local_rollout_data(batch) + + batch = self._generate_completions(rollout_data) + + rewards_per_func = self._score_completions(batch) + + batch['advantages'] = self._compute_advantages(batch, rewards_per_func) + + batch = self._maybe_compute_logps(batch) + return batch def _generate_completions(self, batch): @@ -458,7 +478,7 @@ def _generate_completions(self, batch): batch = self.preprocess_rollout_data(batch) output: List[RolloutOutput] = self._rollout(batch) - batch = self.postprocess_rollout_data(output) + batch = self.postprocess_rollout_data(batch, output) return batch def preprocess_rollout_data(self, batch): @@ -476,13 +496,48 @@ def _rollout(self, batch) -> List[RolloutOutput]: rollout_outputs = self._colocate_rollout(batch, request_config) return rollout_outputs - def postprocess_rollout_data(self, batch): - if self.vllm_tensor_parallel_size == 1: - return batch - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - orig_size = len(batch) // self.vllm_tensor_parallel_size - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - return batch[tp_slice] + def postprocess_rollout_data(self, batch, output): + if self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + orig_size = len(output) // self.vllm_tensor_parallel_size + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + output = output[tp_slice] + + def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput): + response = output.response + choice = response.choices[0] + + # Step 1: Update or append assistant message + if output.messages: + input_data['messages'] = output.messages # Override full message history + else: + # not provided, append + messages = input_data['messages'] + remove_response(messages) + 
messages.append({'role': 'assistant', 'content': choice.message.content}) + + # Step 2: Add token IDs and loss mask + if output.response_token_ids: + input_data['response_token_ids'] = output.response_token_ids + if output.response_loss_mask: + input_data['response_loss_mask'] = output.response_loss_mask + else: + if not self.multi_turn_scheduler: + # for single turn, skip tokenizer response + input_data['response_token_ids'] = output.response.choices[0].token_ids + + # Step 3: Attach rollout extra info + if output.rollout_infos: + input_data['rollout_infos'] = output.rollout_infos + + # Step 4: Store finish reason (used for truncation filters etc.) + input_data['finish_reason'] = choice.finish_reason + input_data['is_truncated'] = choice.finish_reason == 'length' + + return input_data + + assert len(batch) == len(output) + return [merge_output_input_data(input_data, output) for input_data, output in zip(batch, output)] def _get_request_config(self) -> RequestConfig: request_config = copy(self.request_config) @@ -503,3 +558,112 @@ def _get_request_config(self) -> RequestConfig: def _colocate_rollout(self, batch, request_config: RequestConfig): outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) return outputs + + def _score_completions(self, inputs: DataType) -> torch.Tensor: + """Score completions using all reward functions. + + Args: + inputs: List of input examples, each containing a 'messages' list with conversation history + + Returns: + rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with local reward values + """ + # Compute rewards using reward functions + local_rewards_per_func = self._compute_rewards_per_func(inputs) + + return local_rewards_per_func + + def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor: + """Compute rewards using all reward functions""" + device = self.accelerator.device + rewards_per_func = torch.zeros((len(batch), len(self.reward_funcs)), device=device) + completions = [inp['messages'][-1]['content'] for inp in batch] + for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): + with profiling_context(self, reward_func_name): + # reward model + reward_kwargs = {} # TODO: step info + if isinstance(reward_func, nn.Module): + output_reward_func = reward_model_plugin(inputs=batch, **reward_kwargs) + # reward function + else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs.update(RowPreprocessor.rows_to_batched(batch)) + output_reward_func = reward_func(completions, **reward_kwargs) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} + row_reward_kwargs['completion'] = completions[nan_row_idx] + logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. 
' + 'Please ensure that at least one reward function returns a valid reward.') + + return rewards_per_func + + def _compute_advantages(self, batch: DataType, rewards_per_func: torch.Tensor) -> torch.Tensor: + """Compute advantages for RL training.""" + + def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> torch.Tensor: + """Normalize advantages if configured; otherwise, return as-is.""" + if self.args.scale_rewards: + return advantages / (rewards_std + 1e-4) + return advantages + + total_rewards_per_func = gather(rewards_per_func) + rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) + grouped_rewards = rewards.view(-1, self.num_generations) + group_rewards_mean = grouped_rewards.mean(dim=1) + group_rewards_std = grouped_rewards.std(dim=1) + + # Broadcast stats back to the original shape + group_rewards_mean = group_rewards_mean.repeat_interleave(self.num_generations) + group_rewards_std = group_rewards_std.repeat_interleave(self.num_generations) + + # Compute advantages relative to group mean + advantages = rewards - group_rewards_mean + advantages = maybe_normalize_advantages(advantages, group_rewards_std) + + slice_start = self.process_index * len(batch) + slice_end = slice_start + len(batch) + advantages = advantages[slice_start:slice_end] + + return advantages + + def _maybe_compute_logps(self, batch: DataType) -> DataType: + # encode first + template = self.template + # get model forward kwargs from batch + batch = self._maybe_replace_response_token(batch) + with self._disable_maxlength_template_context(template): + batch_encoded_inputs = [template.encode(data, return_length=True) for data in batch] + batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.device) + labels = batch_encoded_inputs.pop('labels') + logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + batch_encoded_inputs['logits_to_keep'] = logits_to_keep + + return batch + + @contextmanager + def _disable_maxlength_template_context(self, template: Template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + template.max_length = None + try: + yield + finally: + template.max_length = max_length + + def _maybe_replace_response_token(self, batch): + # maybe replace the response token with the response token ids to avoid repetitive tokenize + for data in batch: + if 'response_token_ids' in data and data['response_token_ids']: + loss_mask = None + if 'response_loss_mask' in data and data['response_loss_mask']: + loss_mask = data['response_loss_mask'] + # token in token out + data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'], + loss_mask) + return batch diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 4f40f6023e..e864cb636c 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,11 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
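For reference, the group-relative advantage computation in `_compute_advantages` above boils down to the following standalone sketch (function name is illustrative; the cross-rank `gather` of rewards, the reward-weight mixing, and the per-process slicing are omitted, and `scale_rewards` mirrors the Dr. GRPO switch):

    import torch

    def group_relative_advantages(rewards: torch.Tensor, num_generations: int,
                                  scale_rewards: bool = True) -> torch.Tensor:
        # `rewards` is flat and ordered so that the num_generations completions
        # of each prompt are adjacent: [p0_g0, p0_g1, ..., p1_g0, p1_g1, ...]
        grouped = rewards.view(-1, num_generations)
        mean = grouped.mean(dim=1).repeat_interleave(num_generations)
        std = grouped.std(dim=1).repeat_interleave(num_generations)
        advantages = rewards - mean
        if scale_rewards:
            advantages = advantages / (std + 1e-4)
        return advantages

    # two prompts, four completions each
    adv = group_relative_advantages(torch.tensor([1., 0., 0., 1., 2., 2., 2., 2.]), num_generations=4)

A group with identical rewards (the second prompt above) yields zero advantage for every completion and contributes no policy gradient; this is the situation the DAPO-style dynamic sampling flag is meant to address once it is supported.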
+import time +from contextlib import contextmanager from typing import Any, Dict import torch from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank -from megatron.training import get_args +from megatron.training import get_args, get_wandb_writer from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device @@ -150,3 +152,18 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch + + +@contextmanager +def profiling_context(trainer, name: str): + start_time = time.perf_counter() + yield + end_time = time.perf_counter() + duration = end_time - start_time + + profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration} + wandb_writer = get_wandb_writer() + if wandb_writer and trainer.is_main_process: + wandb_writer.log(profiling_metrics) + + # TODO: add swanlab support From 3c69c3937c3b68f84858553e695dc9bee2d6264b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 22 Sep 2025 20:19:43 +0800 Subject: [PATCH 11/83] modify mini_batch_size to generation batch size --- swift/megatron/argument/megatron_args.py | 24 ++- swift/megatron/argument/train_args.py | 3 +- swift/megatron/trainers/dpo_trainer.py | 14 +- swift/megatron/trainers/grpo_trainer.py | 217 ++++++++++++++++------- swift/megatron/trainers/rlhf_base.py | 8 +- swift/megatron/trainers/utils.py | 25 ++- 6 files changed, 204 insertions(+), 87 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 4fa1f2abea..61aaad1f01 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -30,8 +30,9 @@ class RLHFMegatronArgumentsMixin: loss_type: str = 'sigmoid' # =========================== GRPO =========================== + generation_batch_size: Optional[int] = None + steps_per_generation: Optional[int] = None num_generations: int = 8 - mini_batch_size: int = 4 max_completion_length: int = 512 # ─────────────────────────── Sampling ─────────────────────────── @@ -121,9 +122,6 @@ class RLHFMegatronArgumentsMixin: importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' wandb_log_unique_prompts: Optional[bool] = None - generation_batch_size: Optional[int] = None - steps_per_generation: Optional[int] = None - num_iterations: int = 1 # dataset @@ -206,6 +204,24 @@ def _check_not_supported(): self.vllm_mode = 'colocate' logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = 1 + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % self.global_batch_size != 0: + raise ValueError( + f'generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size ' + f'({self.global_batch_size}).') + self.steps_per_generation = self.generation_batch_size // self.global_batch_size + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + else: + raise ValueError( + 
"'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") + world_size = torch.distributed.get_world_size() + self.per_device_generation_batch_size = self.generation_batch_size // world_size + @dataclass class MegatronTunerMixin: diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 9bf4603c20..94d38dc327 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -9,7 +9,7 @@ from swift.llm.argument.base_args import to_abspath from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master from ..model import get_megatron_model_meta -from .megatron_args import MegatronArguments +from .megatron_args import MegatronArguments, RLHFMegatronArgumentsMixin logger = get_logger() @@ -29,6 +29,7 @@ def init_model_args(self, tokenizer, config): if getattr(self, k) is None: setattr(self, k, v) MegatronArguments.__post_init__(self) + RLHFMegatronArgumentsMixin.__post_init__(self) self.extra_args = self.parse_to_megatron() self.extra_args['model_info'] = self.model_info self.extra_args['model_meta'] = self.model_meta diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 145977f5f4..72d57e50b4 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -84,16 +84,6 @@ def _forward_step_helper(model, inputs): return output_tensor - def ref_forward(self, ref_model, data_iterator): - with self.stimer(bdata=True): - data = get_batch(data_iterator) - data.pop('loss_scale', None) - labels = data.get('labels') - with torch.no_grad(): - output_tensor = self._forward_step_helper(ref_model, data) - data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) - return data - @staticmethod def get_logps(output_tensor, labels, packed_seq_params): args = get_args() @@ -147,7 +137,7 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab # fix megatron-lm bug # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 loss = loss / mpu.get_context_parallel_world_size() - return loss, reporting_metric + return loss, metric def _replace_data_iterator(self, data_iterator): args = get_args() @@ -155,7 +145,7 @@ def _replace_data_iterator(self, data_iterator): res = [] with torch.no_grad(), self.null_ref_context() as ref_model: for i in range(num_iters_per_step): - res.append(self.ref_forward(ref_model, data_iterator)) + res.append(self.model_forward(ref_model, data_iterator)) return iter(res) def forward_step(self, data_iterator, model): diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index fd4acadbac..1ce614ce90 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -33,7 +33,7 @@ from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer from .trainer import MegatronTrainer -from .utils import get_batch, profiling_context +from .utils import gather_tensor_dict, get_batch, make_batch_generator, profiling_context try: from mbridge import AutoBridge @@ -49,22 +49,6 @@ has_nvidia_modelopt = False logger = get_logger() -""" -TODO: - 1. compute rewards - 2. compute advantages - 3. compute ref/old logps - 4. loss - - -FUTURE TODO: - 1. server mode - 2. offload model/optimizer - 3. DAPO (dynamic sampling) - 4. entropy mask - 5. reward model - 6. 
multi turn -""" class MegatronGRPOTrainer(MegatronRLHFTrainer): @@ -125,20 +109,6 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab loss = loss / mpu.get_context_parallel_world_size() return loss, reporting_metric - def _replace_data_iterator(self, data_iterator): - # rollout the train_batch_size, split to mini-batches - batch = next(data_iterator) # [global_batch_size, ] - - batch = self._generate_and_score_completions(batch) - - args: MegatronRLHFArguments = get_args() - num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) - res = [] - with torch.no_grad(), self.null_ref_context() as ref_model: - for i in range(num_iters_per_step): - res.append(self.ref_forward(ref_model, data_iterator)) - return iter(res) - def _init_grpo_params(self): args = self.args # distributed params @@ -148,6 +118,7 @@ def _init_grpo_params(self): self.device = get_current_device() # algorithm params self.num_generations = args.num_generations # G in the GRPO paper + self.beta = args.beta self.temperature = args.temperature self.loss_type = args.loss_type self.max_completion_length = args.max_completion_length @@ -157,10 +128,11 @@ def _init_grpo_params(self): self.importance_sampling_level = args.importance_sampling_level self.enable_offload = False # batch size + self.generation_batch_size = args.generation_batch_size + self.steps_per_generation = args.steps_per_generation self.global_batch_size = args.global_batch_size - self.mini_batch_size = args.mini_batch_size self.micro_batch_size = args.micro_batch_size - self.per_device_rollout_batch_size = self.global_batch_size // self.world_size * self.num_generations + self.per_device_generation_batch_size = args.per_device_generation_batch_size # sampling params self.request_config = RequestConfig( n=1, @@ -183,7 +155,7 @@ def _prepare_rollout_engine(self): self.vllm_use_async_engine = False self.enable_offload = False self.use_gym_env = False - self.enable_server_multi_turn = False + self.enable_server_multi_turn = False # TODO # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs self.dynamic_num_samples = False if self.use_vllm: @@ -207,7 +179,7 @@ def _prepare_rollout_engine(self): def prepare_vllm(self): from swift.llm.infer.infer_engine import GRPOVllmEngine args = self.args - max_num_seqs = self.per_device_rollout_batch_size * self.vllm_tensor_parallel_size + max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size * self.num_generations engine = GRPOVllmEngine( self.hf_model_dir, @@ -230,6 +202,7 @@ def prepare_vllm(self): ) if self.vllm_tensor_parallel_size > 1: self.vllm_tp_group = vllm_ps.get_tp_group().device_group + self._buffered_inputs = None return engine def _prepare_rewards(self): @@ -294,6 +267,21 @@ def _patch_megatron(self): super()._patch_megatron() self._origin_train_step = self.train_step + def _replace_data_iterator(self, data_iterator): + + args = get_args() + if args.iteration % self.steps_per_generation == 0: + # gradient_accumulation_steps + num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) + # prepare generation batch data + rollout_batch = [] + for _ in range(self.steps_per_generation): + for _ in range(num_iters_per_step): + rollout_batch.extend(next(data_iterator)) + self._buffered_inputs = self._generate_and_score_completions(rollout_batch) + inputs = self._buffered_inputs[args.iteration % 
self.steps_per_generation] + return make_batch_generator(inputs, batch_size=self.micro_batch_size) + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): # borrowed from Megatron-LM 0.13 # get train_batch_size Rollout / ref/old logps / reward / advantage @@ -302,8 +290,6 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par args: MegatronRLHFArguments = get_args() timers = get_timers() - batch = next(data_iterator) - batch = self._generate_and_score_completions(batch) # split to mini-batches @@ -437,30 +423,100 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad def _generate_and_score_completions(self, batch): - - def get_local_rollout_data(batch): + # batch : same across DP groups + def get_local_rollout_batch(batch): # repeat num_generations times - global_rollout_data = [item for item in batch for _ in range(self.num_generations)] + global_rollout_batch = [item for item in batch for _ in range(self.num_generations)] # get local rollout data - data_slice = slice(self.process_index * self.per_device_rollout_batch_size, - (self.process_index + 1) * self.per_device_rollout_batch_size) - rollout_data = global_rollout_data[data_slice] - return rollout_data + # TODO: check do we should set with_context_parallel? debug with CP > 1 + data_parallel_size = mpu.get_data_parallel_world_size() - rollout_data = get_local_rollout_data(batch) + dp_local_rank = self.process_index % data_parallel_size + dp_group_size = self.world_size // data_parallel_size + assert dp_group_size * self.per_device_generation_batch_size * self.num_generations == len( + global_rollout_batch) + per_device_batch_size = self.per_device_generation_batch_size * self.num_generations + data_slice = slice(dp_local_rank * per_device_batch_size, (dp_local_rank + 1) * per_device_batch_size) + rollout_batch = global_rollout_batch[data_slice] + return rollout_batch - batch = self._generate_completions(rollout_data) + # Step1: get local rollout data in DP group + # rollout_batch : repeat num_generations times, get current process rollout data - rewards_per_func = self._score_completions(batch) + rollout_batch = get_local_rollout_batch(batch) - batch['advantages'] = self._compute_advantages(batch, rewards_per_func) + rollout_batch = self._generate_completions(rollout_batch) - batch = self._maybe_compute_logps(batch) + rewards_per_func = self._score_completions(rollout_batch) - return batch + advantages = self._compute_advantages(rollout_batch, rewards_per_func) + + def _get_encoded_batch(rollout_batch): + template = self.template + encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] + encoded_batch = to_device(template.data_collator(encoded_batch), self.device) + labels = encoded_batch.pop('labels') + logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + if self.template.padding_free: + position_ids = encoded_batch.get('text_position_ids') + if position_ids is None: + position_ids = encoded_batch.get('position_ids') + position_ids = position_ids.squeeze() + assert position_ids is not None + + lengths = torch.diff( + torch.cat([(position_ids == 0).nonzero(as_tuple=True)[0], + torch.tensor([len(position_ids)]).to(position_ids.device)])) + nonlocal advantages + advantages = torch.repeat_interleave(advantages, lengths) + + encoded_batch.update({ + 
'completion_mask': + labels[:, -logits_to_keep:] != -100, + 'truncated_mask': + torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool), + 'advantages': + advantages, + 'position_ids': + position_ids # remove it: non-padding-free + }) + + # Step2: gather in DP group, model forward to get ref/old logps + # prepare model forward kwargs + encoded_batches = [] # [self.steps_per_generation, ] + for _ in range(self.steps_per_generation): + encoded_batch = _get_encoded_batch(rollout_batch) + encoded_batches.append(encoded_batch) + + dp_group = mpu.get_data_parallel_group(with_context_parallel=True) + gathered_encoded_batches = [] # [self.steps_per_generation, ] + for encoded_batch in encoded_batches: + gathered_encoded_batch = gather_tensor_dict(encoded_batch, group=dp_group) + gathered_encoded_batch = self._maybe_compute_logps(gathered_encoded_batch) + gathered_encoded_batches.append(gathered_encoded_batch) + + return gathered_encoded_batches def _generate_completions(self, batch): + """ + Generate completions for a batch of rollout data using vLLM engine. + + This method processes rollout data for the current process, generates completions + using the vLLM engine, and merges the results back into the original batch. + + Args: + batch: Rollout data assigned to the current process. Expected size is + per_device_generation_batch_size. + + Returns: + batch: The input batch with rollout completion results merged in. + + Note: + Currently only supports colocate mode. Server mode support is planned + for future implementation. + """ # TODO: server mode + # assert len(batch) == self.per_device_generation_batch_size assert self.vllm_mode == 'colocate' # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) if self.engine.inner_model_executor.is_sleeping: @@ -482,6 +538,19 @@ def _generate_completions(self, batch): return batch def preprocess_rollout_data(self, batch): + """ + Gather rollout trajectories across the vLLM tensor-parallel (TP) group. + + This method collect the full batch on every rank, then flattens + the nested lists into a single list of samples. + + Args: + batch (list): List of rollout samples local to this TP rank. + + Returns: + list: Flattened list containing all rollout samples from every + rank in the TP group. + """ if self.vllm_tensor_parallel_size == 1: return batch @@ -497,6 +566,21 @@ def _rollout(self, batch) -> List[RolloutOutput]: return rollout_outputs def postprocess_rollout_data(self, batch, output): + """ + Post-process the raw vLLM generation outputs and merge them back into the + original input batch. + + Args: + batch (List[Dict[str, Any]]): + Original rollout samples. + output (List[RolloutOutput]): + outputs from vLLM from vLLM TP group + + Returns: + List[Dict[str, Any]]: + Updated samples with rollout results merged in. 
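+
+        Example (illustrative; the exact keys are set by merge_output_input_data
+        below, and the values here are placeholders):
+
+            {'messages': [..., {'role': 'assistant', 'content': '...'}],
+             'response_token_ids': [...],
+             'rollout_infos': {...},
+             'finish_reason': 'stop',
+             'is_truncated': False}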
+ """ + if self.vllm_tensor_parallel_size > 1: local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) orig_size = len(output) // self.vllm_tensor_parallel_size @@ -522,9 +606,8 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out if output.response_loss_mask: input_data['response_loss_mask'] = output.response_loss_mask else: - if not self.multi_turn_scheduler: - # for single turn, skip tokenizer response - input_data['response_token_ids'] = output.response.choices[0].token_ids + # for single turn, skip tokenizer response + input_data['response_token_ids'] = output.response.choices[0].token_ids # Step 3: Attach rollout extra info if output.rollout_infos: @@ -547,7 +630,7 @@ def _get_request_config(self) -> RequestConfig: # otherwise, the program may hang. # 2. Ensure that the seed for vLLM Engines across different TP groups is different; # otherwise, identical completions will be generated. - batch_size = self.per_device_rollout_batch_size + batch_size = self.per_device_generation_batch_size batch_size *= self.vllm_tensor_parallel_size # Since the TP (Tensor Parallelism) group gathers the inputs, # multiply the batch size by the TP parallel size. @@ -582,7 +665,7 @@ def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor: zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): with profiling_context(self, reward_func_name): # reward model - reward_kwargs = {} # TODO: step info + reward_kwargs = {} # TODO: training step info if isinstance(reward_func, nn.Module): output_reward_func = reward_model_plugin(inputs=batch, **reward_kwargs) # reward function @@ -633,17 +716,15 @@ def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tens return advantages def _maybe_compute_logps(self, batch: DataType) -> DataType: - # encode first - template = self.template - # get model forward kwargs from batch - batch = self._maybe_replace_response_token(batch) - with self._disable_maxlength_template_context(template): - batch_encoded_inputs = [template.encode(data, return_length=True) for data in batch] - batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.device) - labels = batch_encoded_inputs.pop('labels') - logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() - batch_encoded_inputs['logits_to_keep'] = logits_to_keep - + # TODO: entropy + if self.beta != 0.0: + with torch.no_grad(), self.null_ref_context() as ref_model: + batch['ref_per_token_logps'] = self.model_forward( + ref_model, make_batch_generator(batch, self.micro_batch_size), no_grad=True)['logps'] + + if not self.on_policy: + batch['old_per_token_logps'] = self.model_forward( + self.unwrapped_model, make_batch_generator(batch, self.micro_batch_size), no_grad=True)['logps'] return batch @contextmanager @@ -667,3 +748,7 @@ def _maybe_replace_response_token(self, batch): data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'], loss_mask) return batch + + @property + def on_policy(self): + return self.steps_per_generation == 1 diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index daa5d752e9..d1d0b600f0 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -63,13 +63,15 @@ def _forward_step_helper(model, inputs): return output_tensor - def ref_forward(self, ref_model, data_iterator): + def model_forward(self, model, data_iterator, no_grad=True): + # 
used to calculate model forward (logps) with self.stimer(bdata=True): data = get_batch(data_iterator) data.pop('loss_scale', None) labels = data.get('labels') - with torch.no_grad(): - output_tensor = self._forward_step_helper(ref_model, data) + context = torch.no_grad() if no_grad else nullcontext() + with context: + output_tensor = self._forward_step_helper(model, data) data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) return data diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 5f5c62b6ad..eaf13f4fe1 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import time from contextlib import contextmanager -from typing import Any, Dict +from typing import Any, Dict, List import torch from megatron.core import mpu @@ -170,3 +170,26 @@ def profiling_context(trainer, name: str): wandb_writer.log(profiling_metrics) # TODO: add swanlab support + + +def gather_tensor_dict(tensors: Dict[str, torch.Tensor], group): + if not isinstance(tensors, dict): + raise ValueError(f'Expected a dictionary, got {type(tensors)}') + size = torch.distributed.get_world_size(group=group) + + output = {} + sorted_keys = sorted(tensors.keys()) + for key in sorted_keys: + val = tensors[key] + output[key] = [torch.empty_like(val) for _ in range(size)] + torch.distributed.all_gather(val, val, group=group, async_op=False) + output[key] = torch.cat(output[key], dim=0) + + return output + + +def make_batch_generator(batch: List[Dict[str, Any]], batch_size: int): + assert batch_size > 0, 'batch_size must be positive' + assert len(batch) % batch_size == 0, 'batch length must be a multiple of batch_size' + for i in range(0, len(batch), batch_size): + yield batch[i:i + batch_size] From eebdd47cddc4c951ca3d7897680a026a7a9d61e5 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 24 Sep 2025 16:21:54 +0800 Subject: [PATCH 12/83] wip --- swift/megatron/train/rlhf.py | 3 +- swift/megatron/trainers/grpo_trainer.py | 94 +++++++++++++++---------- swift/megatron/trainers/rlhf_base.py | 33 +++++++++ swift/megatron/trainers/utils.py | 30 ++++++-- 4 files changed, 116 insertions(+), 44 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index e98dd3c2c8..8152beac5f 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -26,7 +26,8 @@ def prepare_trainer(self): def _prepare_template(self) -> None: super()._prepare_template() - self.template.set_mode('rlhf') + model_mapping = {'grpo': 'train'} + self.template.set_mode(model_mapping.get(self.args.rlhf_type, 'rlhf')) def _get_data_collator(self): if self.args.rlhf_type == 'grpo': diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 1ce614ce90..e67ebbfa0a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn -from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from accelerate.utils import broadcast_object_list, gather, is_peft_model, set_seed from megatron.core import mpu from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank from megatron.core.num_microbatches_calculator import get_num_microbatches @@ -33,7 +33,7 @@ from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer 
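The helpers added to trainers/utils.py above are intentionally small; a minimal standalone sketch of how they are meant to compose (illustrative names, assuming torch.distributed is already initialized and `group` is the rollout or data-parallel process group):

    from typing import Any, Dict, List, Optional

    import torch.distributed as dist

    def gather_list_of_dicts(local_batch: List[Dict[str, Any]],
                             group: Optional[dist.ProcessGroup] = None) -> List[Dict[str, Any]]:
        # same idea as gather_object(): all_gather_object, then flatten
        world_size = dist.get_world_size(group=group)
        buckets: List[Any] = [None] * world_size
        dist.all_gather_object(buckets, local_batch, group=group)
        return [sample for bucket in buckets for sample in bucket]

    def iter_micro_batches(batch: List[Dict[str, Any]], micro_batch_size: int):
        # same idea as make_batch_generator(): fixed-size, in-order slices
        assert micro_batch_size > 0 and len(batch) % micro_batch_size == 0
        for i in range(0, len(batch), micro_batch_size):
            yield batch[i:i + micro_batch_size]

make_batch_generator requires the batch length to be a multiple of the micro-batch size, which matches how _replace_data_iterator assembles the generation batch from whole micro-batches before handing it back to the Megatron training loop.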
from .trainer import MegatronTrainer -from .utils import gather_tensor_dict, get_batch, make_batch_generator, profiling_context +from .utils import gather_dict, gather_object, get_batch, make_batch_generator, profiling_context try: from mbridge import AutoBridge @@ -62,7 +62,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template): self._init_grpo_params() self._prepare_rewards() self._prepare_rollout_engine() - # debug: use mbridge to convert mcore to hf self.bridge = None @@ -127,12 +126,13 @@ def _init_grpo_params(self): self.top_entropy_quantile = args.top_entropy_quantile self.importance_sampling_level = args.importance_sampling_level self.enable_offload = False - # batch size + # batch size (completion-level) self.generation_batch_size = args.generation_batch_size self.steps_per_generation = args.steps_per_generation self.global_batch_size = args.global_batch_size self.micro_batch_size = args.micro_batch_size self.per_device_generation_batch_size = args.per_device_generation_batch_size + # sampling params self.request_config = RequestConfig( n=1, @@ -176,11 +176,25 @@ def _prepare_rollout_engine(self): if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) + self._init_rollout_group() + + def _init_rollout_group(self): + args = self.args + model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size + # each model share the rollout group (gather) + rollout_groups = [list(range(i, i + model_size)) for i in range(0, self.world_size, model_size)] + + for group_ranks in rollout_groups: + if self.process_index in group_ranks: + self.rollout_group = torch.distributed.new_group(ranks=group_ranks) + break + def prepare_vllm(self): from swift.llm.infer.infer_engine import GRPOVllmEngine args = self.args max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size * self.num_generations - + vllm_template = copy(self.template) + vllm_template.padding_free = False engine = GRPOVllmEngine( self.hf_model_dir, args.torch_dtype, @@ -197,7 +211,7 @@ def prepare_vllm(self): seed=self.process_index // self.vllm_tensor_parallel_size, disable_cascade_attn=self.args.vllm_disable_cascade_attn, load_format='dummy', - template=self.template, + template=vllm_template, distributed_executor_backend='external_launcher', ) if self.vllm_tensor_parallel_size > 1: @@ -211,6 +225,8 @@ def _prepare_rewards(self): reward_funcs = args.reward_funcs if not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] + + # initilize reward functions if reward_funcs: for i, reward_func in enumerate(reward_funcs): if reward_func in orms: @@ -225,6 +241,8 @@ def _prepare_rewards(self): reward_funcs[i] = reward_func_class(**reward_func_kwargs) elif not callable(reward_func): raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') + + # get reward name for logging self.reward_funcs = reward_funcs self.reward_func_names = [] for reward_func in reward_funcs: @@ -234,6 +252,7 @@ def _prepare_rewards(self): reward_func_name = reward_func.__class__.__name__ self.reward_func_names.append(reward_func_name) + # set reward weights if args.reward_weights is not None: if len(args.reward_weights) != len(reward_funcs): raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' @@ -242,6 +261,11 @@ def _prepare_rewards(self): else: self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(self.device) + # TODO: reward models + 
self.reward_model_plugins = [None] * len(self.reward_funcs) + + assert self.reward_funcs, 'reward_funcs is not set' + def _move_model_to_vllm(self): # TODO: LoRA, server if self.bridge is None: @@ -428,15 +452,12 @@ def get_local_rollout_batch(batch): # repeat num_generations times global_rollout_batch = [item for item in batch for _ in range(self.num_generations)] # get local rollout data - # TODO: check do we should set with_context_parallel? debug with CP > 1 - data_parallel_size = mpu.get_data_parallel_world_size() - - dp_local_rank = self.process_index % data_parallel_size - dp_group_size = self.world_size // data_parallel_size - assert dp_group_size * self.per_device_generation_batch_size * self.num_generations == len( + rollout_rank = torch.distributed.get_rank(group=self.rollout_group) + rollout_group_size = torch.distributed.get_world_size(group=self.rollout_group) + assert rollout_group_size * self.per_device_generation_batch_size * self.num_generations == len( global_rollout_batch) per_device_batch_size = self.per_device_generation_batch_size * self.num_generations - data_slice = slice(dp_local_rank * per_device_batch_size, (dp_local_rank + 1) * per_device_batch_size) + data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) rollout_batch = global_rollout_batch[data_slice] return rollout_batch @@ -451,51 +472,49 @@ def get_local_rollout_batch(batch): advantages = self._compute_advantages(rollout_batch, rewards_per_func) - def _get_encoded_batch(rollout_batch): + def _get_encoded_batch(rollout_batch, advantages): template = self.template encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] encoded_batch = to_device(template.data_collator(encoded_batch), self.device) - labels = encoded_batch.pop('labels') + labels = encoded_batch['labels'] logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() if self.template.padding_free: position_ids = encoded_batch.get('text_position_ids') if position_ids is None: position_ids = encoded_batch.get('position_ids') - position_ids = position_ids.squeeze() - assert position_ids is not None + squeezed_position_ids = position_ids.squeeze() + assert squeezed_position_ids is not None lengths = torch.diff( - torch.cat([(position_ids == 0).nonzero(as_tuple=True)[0], - torch.tensor([len(position_ids)]).to(position_ids.device)])) - nonlocal advantages + torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], + torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) advantages = torch.repeat_interleave(advantages, lengths) encoded_batch.update({ 'completion_mask': labels[:, -logits_to_keep:] != -100, 'truncated_mask': - torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool), + torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device), 'advantages': advantages, 'position_ids': position_ids # remove it: non-padding-free }) + return encoded_batch + # Step2: gather in DP group, model forward to get ref/old logps # prepare model forward kwargs - encoded_batches = [] # [self.steps_per_generation, ] - for _ in range(self.steps_per_generation): - encoded_batch = _get_encoded_batch(rollout_batch) - encoded_batches.append(encoded_batch) - - dp_group = mpu.get_data_parallel_group(with_context_parallel=True) - gathered_encoded_batches = [] # [self.steps_per_generation, ] - for encoded_batch in encoded_batches: - gathered_encoded_batch = 
gather_tensor_dict(encoded_batch, group=dp_group) - gathered_encoded_batch = self._maybe_compute_logps(gathered_encoded_batch) - gathered_encoded_batches.append(gathered_encoded_batch) + total_batch = gather_object(rollout_batch, group=self.rollout_group) + # len(g_batch) = dp_world_size * self.per_device_generation_batch_size * self.num_generations + mini_batch_data = [] + for idx in range(0, len(total_batch), self.micro_batch_size): + micro_batch_data = _get_encoded_batch(rollout_batch[idx:idx + self.micro_batch_size], + advantages[idx:idx + self.micro_batch_size]) + micro_batch_data = self._maybe_compute_logps(micro_batch_data) + mini_batch_data.append(micro_batch_data) - return gathered_encoded_batches + return mini_batch_data def _generate_completions(self, batch): """ @@ -658,14 +677,14 @@ def _score_completions(self, inputs: DataType) -> torch.Tensor: def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor: """Compute rewards using all reward functions""" - device = self.accelerator.device + device = self.device rewards_per_func = torch.zeros((len(batch), len(self.reward_funcs)), device=device) completions = [inp['messages'][-1]['content'] for inp in batch] + reward_kwargs = {} # TODO: training step info for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): with profiling_context(self, reward_func_name): # reward model - reward_kwargs = {} # TODO: training step info if isinstance(reward_func, nn.Module): output_reward_func = reward_model_plugin(inputs=batch, **reward_kwargs) # reward function @@ -719,12 +738,11 @@ def _maybe_compute_logps(self, batch: DataType) -> DataType: # TODO: entropy if self.beta != 0.0: with torch.no_grad(), self.null_ref_context() as ref_model: - batch['ref_per_token_logps'] = self.model_forward( - ref_model, make_batch_generator(batch, self.micro_batch_size), no_grad=True)['logps'] + batch['ref_per_token_logps'] = self.model_forward(ref_model, iter([batch]), no_grad=True)['logps'] if not self.on_policy: batch['old_per_token_logps'] = self.model_forward( - self.unwrapped_model, make_batch_generator(batch, self.micro_batch_size), no_grad=True)['logps'] + self.unwrapped_model, iter([batch]), no_grad=True)['logps'] return batch @contextmanager diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index d1d0b600f0..563fd2fec9 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -20,6 +20,39 @@ class MegatronRLHFTrainer(MegatronTrainer): + @contextmanager + def _get_iters(self, train_dataset, val_dataset): + origin_initialize_megatron = training.initialize_megatron + + def initialize_megatron(*_args, **kwargs): + res = origin_initialize_megatron(*_args, **kwargs) + args = get_args() + data_parallel_size = mpu.get_data_parallel_world_size() + step_batch_size = args.micro_batch_size * data_parallel_size + if args.train_iters is None and args.max_epochs is not None: + if hasattr(train_dataset, '__len__'): + dataset_sample = len(train_dataset) // step_batch_size * step_batch_size * args.num_generations + args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size + else: + raise ValueError( + 'You are using a streaming training dataset. 
Please explicitly specify `--train_iters`.') + if args.eval_iters < 0: + if val_dataset is None: + args.eval_iters = 0 + elif hasattr(val_dataset, '__len__'): + dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + args.eval_iters = max(dataset_sample // args.global_batch_size, 1) + else: + raise ValueError( + 'You are using a streaming validation dataset. Please explicitly specify `--eval_iters`.') + return res + + training.initialize_megatron = initialize_megatron + try: + yield + finally: + training.initialize_megatron = origin_initialize_megatron + def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): args = get_args() if args.train_type == 'full': diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index eaf13f4fe1..55ed65d075 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,9 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import time from contextlib import contextmanager -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch +from accelerate.utils import gather_object as hf_gather_object from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank @@ -172,7 +173,7 @@ def profiling_context(trainer, name: str): # TODO: add swanlab support -def gather_tensor_dict(tensors: Dict[str, torch.Tensor], group): +def gather_dict(tensors: Dict[str, torch.Tensor], group: torch.distributed.ProcessGroup): if not isinstance(tensors, dict): raise ValueError(f'Expected a dictionary, got {type(tensors)}') size = torch.distributed.get_world_size(group=group) @@ -181,13 +182,32 @@ def gather_tensor_dict(tensors: Dict[str, torch.Tensor], group): sorted_keys = sorted(tensors.keys()) for key in sorted_keys: val = tensors[key] - output[key] = [torch.empty_like(val) for _ in range(size)] - torch.distributed.all_gather(val, val, group=group, async_op=False) - output[key] = torch.cat(output[key], dim=0) + if isinstance(val, int): + # num_samples + output[key] = val + continue + elif isinstance(val, torch.Tensor): + output[key] = [torch.empty_like(val) for _ in range(size)] + torch.distributed.all_gather(output[key], val, group=group, async_op=False) + output[key] = torch.cat(output[key], dim=0) + else: + output[key] = [None for _ in range(size)] + torch.distributed.all_gather_object(output[key], val, group=group, async_op=False) + output[key] = [item for sublist in output[key] for item in sublist] return output +def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = None): + if group is None: + return hf_gather_object(object) + size = torch.distributed.get_world_size(group=group) + output_objects = [None for _ in range(size)] + torch.distributed.all_gather_object(output_objects, object) + # all_gather_object returns a list of lists, so we need to flatten it + return [x for y in output_objects for x in y] + + def make_batch_generator(batch: List[Dict[str, Any]], batch_size: int): assert batch_size > 0, 'batch_size must be positive' assert len(batch) % batch_size == 0, 'batch length must be a multiple of batch_size' From de6ecfe33922208ca378b5be3c99c2a86b99713f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 28 Sep 2025 14:38:13 +0800 Subject: [PATCH 13/83] loss wip --- swift/megatron/argument/megatron_args.py | 43 +-- swift/megatron/trainers/grpo_trainer.py | 370 ++++++++--------------- 
swift/megatron/trainers/rlhf_base.py | 32 +- swift/megatron/trainers/utils.py | 11 + 4 files changed, 189 insertions(+), 267 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 61aaad1f01..25cc4b54a6 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -168,8 +168,33 @@ def _check_not_supported(): if self.num_iterations > 1: raise ValueError('num_iterations > 1 is not supported for Megatron-GRPO yet, please set it to 1.') + def _check_batch_params(): + assert self.micro_batch_size % self.num_generations == 0, \ + f'micro_batch_size ({self.micro_batch_size}) must be divisible' \ + f' by the number of generations ({self.num_generations})' + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = 1 + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % self.global_batch_size != 0: + raise ValueError(f'generation_batch_size ({self.generation_batch_size}) ' + f'must be divisible by the global batch size ({self.global_batch_size}).') + self.steps_per_generation = self.generation_batch_size // self.global_batch_size + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") + world_size = torch.distributed.get_world_size() + assert self.generation_batch_size % world_size == 0, \ + f'generation_batch_size ({self.generation_batch_size}) ' \ + f'must be divisible by the world size ({world_size})' + self.per_device_generation_batch_size = self.generation_batch_size // world_size + _init_external_vllm() _check_not_supported() + _check_batch_params() if self.async_generate or not self.use_vllm: self.sleep_level = 0 self.remove_unused_columns = False @@ -204,24 +229,6 @@ def _check_not_supported(): self.vllm_mode = 'colocate' logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') - if self.generation_batch_size is None and self.steps_per_generation is None: - self.steps_per_generation = 1 - self.generation_batch_size = self.global_batch_size * self.steps_per_generation - elif self.generation_batch_size is not None and self.steps_per_generation is None: - # Just ensure the value is divisible by the global batch size - if self.generation_batch_size % self.global_batch_size != 0: - raise ValueError( - f'generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size ' - f'({self.global_batch_size}).') - self.steps_per_generation = self.generation_batch_size // self.global_batch_size - elif self.generation_batch_size is None and self.steps_per_generation is not None: - self.generation_batch_size = self.global_batch_size * self.steps_per_generation - else: - raise ValueError( - "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") - world_size = torch.distributed.get_world_size() - self.per_device_generation_batch_size = self.generation_batch_size // world_size - @dataclass class MegatronTunerMixin: diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 
e67ebbfa0a..1b6ea01e4d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -9,45 +9,25 @@ import torch import torch.nn as nn -from accelerate.utils import broadcast_object_list, gather, is_peft_model, set_seed from megatron.core import mpu -from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank -from megatron.core.num_microbatches_calculator import get_num_microbatches -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.training import get_args, get_model, get_timers, training -from megatron.training.checkpointing import load_checkpoint -from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks -from megatron.training.utils import (logical_and_across_model_parallel_group, - reduce_max_stat_across_model_parallel_group, unwrap_model) -from torch.distributed.nn import all_reduce +from megatron.training import get_args, training from vllm.distributed import parallel_state as vllm_ps from swift.llm import RequestConfig, RowPreprocessor, Template, to_device from swift.llm.infer.protocol import RolloutOutput from swift.plugin import orms -from swift.trainers.rlhf_trainer import GRPOTrainer, VLLMClient from swift.trainers.rlhf_trainer.grpo_trainer import DataType from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer -from .trainer import MegatronTrainer -from .utils import gather_dict, gather_object, get_batch, make_batch_generator, profiling_context +from .utils import gather, gather_dict, gather_object, profiling_context try: from mbridge import AutoBridge except ImportError: pass -try: - from megatron.post_training.algos.distillation import ( - get_tensor_shapes_adjust_fn_for_distillation, ) - - has_nvidia_modelopt = True -except ImportError: - has_nvidia_modelopt = False - logger = get_logger() @@ -65,49 +45,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template): # debug: use mbridge to convert mcore to hf self.bridge = None - def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, - packed_seq_params): - # TODO:GRPO policy loss - args: MegatronRLHFArguments = get_args() - num_samples = packed_seq_params.num_samples - - logps = self.get_logps(output_tensor, labels, packed_seq_params) - loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( - logps[:num_samples], - logps[num_samples:], - ref_logps[:num_samples], - ref_logps[num_samples:], - ) - if args.rpo_alpha: - loss_mask = labels != -100 - num_tokens = packed_seq_params.cu_seqlens_q[num_samples] // args.context_parallel_size - loss_mask[:, num_tokens:] = 0 - nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]]) - if args.context_parallel_size > 1: - nll_loss = all_reduce(nll_loss, group=mpu.get_context_parallel_group()) - nll_loss = nll_loss[0] / nll_loss[1] - loss = loss + args.rpo_alpha * nll_loss - loss = loss.mean() - metric = { - 'loss': loss.clone().detach(), - 'logps/chosen': logps[:num_samples].mean(), - 'logps/rejected': logps[num_samples:].mean(), - 'rewards/chosen': chosen_rewards.mean(), - 'rewards/rejected': rejected_rewards.mean(), - 'rewards/accuracies': 
(chosen_rewards > rejected_rewards).float().mean(), - 'rewards/margins': (chosen_rewards - rejected_rewards).mean(), - } - if args.rpo_alpha: - metric['nll_loss'] = nll_loss.detach() - reporting_metric = loss.new_tensor(list(metric.values())) - torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) - reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} - # fix megatron-lm bug - # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 - loss = loss / mpu.get_context_parallel_world_size() - return loss, reporting_metric - def _init_grpo_params(self): args = self.args # distributed params @@ -192,7 +129,7 @@ def _init_rollout_group(self): def prepare_vllm(self): from swift.llm.infer.infer_engine import GRPOVllmEngine args = self.args - max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size * self.num_generations + max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size vllm_template = copy(self.template) vllm_template.padding_free = False engine = GRPOVllmEngine( @@ -219,6 +156,13 @@ def prepare_vllm(self): self._buffered_inputs = None return engine + def _move_model_to_vllm(self): + # TODO: LoRA, server + if self.bridge is None: + self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) + per_tensor_params = self.bridge.export_weights([self.unwrapped_model]) + self.engine.inner_model.load_weights(per_tensor_params) # TODO: check tensor_model_parallel + def _prepare_rewards(self): # TODO: reward model args = self.args @@ -266,185 +210,112 @@ def _prepare_rewards(self): assert self.reward_funcs, 'reward_funcs is not set' - def _move_model_to_vllm(self): - # TODO: LoRA, server - if self.bridge is None: - self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) - per_tensor_params = self.bridge.export_weights([self.unwrapped_model]) - self.engine.inner_model.load_weights(per_tensor_params) # TODO: check tensor_model_parallel - def forward_step(self, data_iterator, model): # train_batch_size # return: output_tensor, loss_func data = next(data_iterator) - ref_logps = data.pop('logps') + inputs = { + k: v + for k, v in inputs.items() if k not in + ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask'] + } + with self.stimer: - output_tensor = model(**data) - return output_tensor, partial( - self.loss_func, - ref_logps=ref_logps, - labels=data.get('labels'), - packed_seq_params=data.get('packed_seq_params')) + output_tensor = model(**inputs) + return output_tensor, partial(self.loss_func, data=data) + + def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): + advantages = data['advantages'] + labels = data['labels'] + completion_mask = data['completion_mask'] + packed_seq_params = data['packed_seq_params'] + + per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params) + if self.beta != 0.0: + ref_per_token_logps = data.get('ref_per_token_logps') + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + old_per_token_logps = ( + per_token_logps.detach() if data.get('old_per_token_logps') is None else data['old_per_token_logps']) + log_ratio = per_token_logps - old_per_token_logps + + if self.importance_sampling_level == 'token': + log_importance_weights = log_ratio + elif self.importance_sampling_level == 'sequence': + log_importance_weights = (log_ratio * 
completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + elif self.importance_sampling_level == 'sequence_token': + # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] + seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'.") + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + if self.template.padding_free: + advantages = advantages[-coef_1.shape[1]:] + per_token_loss1 = coef_1 * advantages.unsqueeze(0) + per_token_loss2 = coef_2 * advantages.unsqueeze(0) + else: + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') - def _patch_megatron(self): - super()._patch_megatron() - self._origin_train_step = self.train_step + loss = loss.mean() + metric = { + 'loss': loss.clone().detach(), + } + reporting_metric = loss.new_tensor(list(metric.values())) + torch.distributed.all_reduce( + reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} + # fix megatron-lm bug + # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 + loss = loss / mpu.get_context_parallel_world_size() + return loss, reporting_metric def _replace_data_iterator(self, data_iterator): args = get_args() if args.iteration % self.steps_per_generation == 0: - # gradient_accumulation_steps - num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) - # prepare generation batch data + # each rollout DP group will generate generation_batch_size / world_size completions + num_iters_per_step = self.generation_batch_size // mpu.get_data_parallel_world_size() + # completions will be repeated num_generations times after + num_iters_per_step = num_iters_per_step // self.num_generations rollout_batch = [] for _ in range(self.steps_per_generation): for _ in range(num_iters_per_step): rollout_batch.extend(next(data_iterator)) - self._buffered_inputs = self._generate_and_score_completions(rollout_batch) + micro_batch_data = self._generate_and_score_completions(rollout_batch) + num_mini_batch = self.global_batch_size // (self.micro_batch_size * mpu.get_data_parallel_world_size()) + mini_batch_data = [ + micro_batch_data[i:i + num_mini_batch] for i in range(0, len(micro_batch_data), 
num_mini_batch) + ] + assert len(mini_batch_data) == self.steps_per_generation + self._buffered_inputs = mini_batch_data + inputs = self._buffered_inputs[args.iteration % self.steps_per_generation] - return make_batch_generator(inputs, batch_size=self.micro_batch_size) - - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - # borrowed from Megatron-LM 0.13 - # get train_batch_size Rollout / ref/old logps / reward / advantage - # split to mini_batches (iter mini_batch) - data_iterator = self._replace_data_iterator(data_iterator) - - args: MegatronRLHFArguments = get_args() - timers = get_timers() - - # split to mini-batches - - # CUDA Graph capturing only executes once, when it's the first training iteration. - if args.curr_iteration == args.iteration and args.external_cuda_graph: - cuda_graph_capture(model, config, args) - - # Set grad to zero. - for model_chunk in model: - model_chunk.zero_grad_buffer() - optimizer.zero_grad() - - # Collect garbage and empty unused memory. - gc.collect() - torch.cuda.empty_cache() - - rerun_state_machine = get_rerun_state_machine() - while rerun_state_machine.should_run_forward_backward(data_iterator): - # Set grad to zero. - for model_chunk in model: - model_chunk.zero_grad_buffer() - optimizer.zero_grad() - - if has_nvidia_modelopt: - # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors - adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( - model, args.seq_length, args.micro_batch_size, args.decoder_seq_length) - else: - adjust_tensor_shapes_fn = None - - # Forward pass. - forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=get_num_microbatches(), - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=False, - adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, - ) - should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() - if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None - - # Empty unused memory. - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - # Vision gradients. - if args.vision_pretraining and args.vision_pretraining_type == 'dino': - unwrapped_model = unwrap_model(model[0]) - unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) - - # Update parameters. - - timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) - update_successful, grad_norm, num_zeros_in_grad = optimizer.step() - timers('optimizer').stop() - - # when freezing sub-models we may have a mixture of successful and unsucessful ranks, - # so we must gather across mp ranks - update_successful = logical_and_across_model_parallel_group(update_successful) - # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, - # so we must gather across mp ranks - grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) - if args.log_num_zeros_in_grad: - num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad) - - # Vision momentum. - if args.vision_pretraining and args.vision_pretraining_type == 'dino': - unwrapped_model = unwrap_model(model[0]) - unwrapped_model.update_momentum(args.curr_iteration) - - # Update learning rate. 
- if update_successful: - increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size - opt_param_scheduler.step(increment=increment) - skipped_iter = 0 - else: - skipped_iter = 1 - - # Empty unused memory. - if args.empty_unused_memory_level >= 2: - torch.cuda.empty_cache() - - # Set the manual hooks when CUDA Graphs are enabled. - if args.curr_iteration == args.iteration and args.external_cuda_graph: - if args.use_distributed_optimizer and args.overlap_param_gather: - cuda_graph_set_manual_hooks(model) - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Average loss across microbatches. - loss_reduced = {} - - for key in losses_reduced[0].keys(): - val = [x[key].view(-1) for x in losses_reduced] - if val[0].numel() == 2: - if args.sft: - # in mcore the normalization happens on micro batch instead of global - val = torch.vstack(val) - val = val[:, 0] / val[:, 1] - val = val.mean() - torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) - val /= torch.distributed.get_world_size( - group=mpu.get_data_parallel_group(with_context_parallel=True)) - loss_reduced[key] = val - else: - # there is one dict per microbatch. in new reporting, we average - # over the total number of tokens across the global batch. - val = torch.vstack(val).sum(dim=0) - torch.distributed.all_reduce(val, group=mpu.get_data_parallel_group(with_context_parallel=True)) - loss_reduced[key] = val[0] / val[1] - elif val[0].numel() == 1: - # legacy behavior, we average over the number of microbatches - val = torch.cat(val).mean() - loss_reduced[key] = val - else: - raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}') - return ( - loss_reduced, - skipped_iter, - should_checkpoint, - should_exit, - exit_code, - grad_norm, - num_zeros_in_grad, - ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad + return iter(inputs) def _generate_and_score_completions(self, batch): # batch : same across DP groups @@ -454,9 +325,8 @@ def get_local_rollout_batch(batch): # get local rollout data rollout_rank = torch.distributed.get_rank(group=self.rollout_group) rollout_group_size = torch.distributed.get_world_size(group=self.rollout_group) - assert rollout_group_size * self.per_device_generation_batch_size * self.num_generations == len( - global_rollout_batch) - per_device_batch_size = self.per_device_generation_batch_size * self.num_generations + per_device_batch_size = self.per_device_generation_batch_size + assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) rollout_batch = global_rollout_batch[data_slice] return rollout_batch @@ -506,11 +376,11 @@ def _get_encoded_batch(rollout_batch, advantages): # Step2: gather in DP group, model forward to get ref/old logps # prepare model forward kwargs total_batch = gather_object(rollout_batch, group=self.rollout_group) - # len(g_batch) = dp_world_size * self.per_device_generation_batch_size * self.num_generations + total_advantages = gather(advantages, group=self.rollout_group) mini_batch_data = [] for idx in range(0, len(total_batch), self.micro_batch_size): - micro_batch_data = _get_encoded_batch(rollout_batch[idx:idx + self.micro_batch_size], - advantages[idx:idx + self.micro_batch_size]) + micro_batch_data = _get_encoded_batch(total_batch[idx:idx + self.micro_batch_size], + total_advantages[idx:idx + self.micro_batch_size]) 
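+            # ref/old per-token logps are attached to each encoded micro-batch below before it is buffered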
micro_batch_data = self._maybe_compute_logps(micro_batch_data) mini_batch_data.append(micro_batch_data) @@ -524,18 +394,15 @@ def _generate_completions(self, batch): using the vLLM engine, and merges the results back into the original batch. Args: - batch: Rollout data assigned to the current process. Expected size is - per_device_generation_batch_size. + batch: Rollout data assigned to the current process. Returns: batch: The input batch with rollout completion results merged in. Note: - Currently only supports colocate mode. Server mode support is planned - for future implementation. + Currently only supports colocate mode. Server mode support is planned for future implementation. """ # TODO: server mode - # assert len(batch) == self.per_device_generation_batch_size assert self.vllm_mode == 'colocate' # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) if self.engine.inner_model_executor.is_sleeping: @@ -770,3 +637,28 @@ def _maybe_replace_response_token(self, batch): @property def on_policy(self): return self.steps_per_generation == 1 + + @contextmanager + def patch_megatron_data_collator(self, data_collator): + """ + Context manager that temporarily patches Megatron's data-loader factory so each + prompt-level micro-batch size equals (original micro-batch size // num_generations), + required by GRPO. Restores the original size and loader on exit. + """ + origin_build_pretraining_data_loader = training.build_pretraining_data_loader + + def build_pretraining_data_loader(*_args, **kwargs): + args = get_args() + org_micro_batch_size = args.micro_batch_size + args.micro_batch_size = org_micro_batch_size // self.num_generations + res = origin_build_pretraining_data_loader(*_args, **kwargs) + args.micro_batch_size = org_micro_batch_size + if res is not None and args.dataloader_type != 'external': + res.collate_fn = data_collator + return res + + training.build_pretraining_data_loader = build_pretraining_data_loader + try: + yield + finally: + training.build_pretraining_data_loader = origin_build_pretraining_data_loader diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index 563fd2fec9..c9a9934a49 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -105,24 +105,36 @@ def model_forward(self, model, data_iterator, no_grad=True): context = torch.no_grad() if no_grad else nullcontext() with context: output_tensor = self._forward_step_helper(model, data) - data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) + data['logps'] = None if labels is None else self.get_logps( + output_tensor, labels, data['packed_seq_params'], per_token=True) return data @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): + def get_logps(output_tensor, labels, packed_seq_params, per_token: bool = False): args = get_args() per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask num_samples = packed_seq_params.num_samples - cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples * 2, )) - for i in range(num_samples * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps + if args.rlhf_type == 'dpo': + total_samples = num_samples * 2 
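+        # DPO packs a chosen and a rejected sequence per sample (hence 2 * num_samples); GRPO packs a single completion per sample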
+ elif args.rlhf_type in 'grpo': + total_samples = num_samples + + cu_seqlens = packed_seq_params.cu_seqlens_q[:total_samples + 1] // args.context_parallel_size + + if per_token: + if args.context_parallel_size > 1: + per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) + return per_token_logps + else: + all_logps = per_token_logps.new_zeros((total_samples, )) + for i in range(total_samples): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) + return all_logps @contextmanager def null_ref_context(self): diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 55ed65d075..7ffc33ff27 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional import torch +from accelerate.utils import gather as hf_gather from accelerate.utils import gather_object as hf_gather_object from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams @@ -198,6 +199,16 @@ def gather_dict(tensors: Dict[str, torch.Tensor], group: torch.distributed.Proce return output +def gather(tensor, group: Optional[torch.distributed.ProcessGroup] = None): + if group is None: + return hf_gather(tensor) + size = torch.distributed.get_world_size(group=group) + output = [torch.empty_like(tensor) for _ in range(size)] + torch.distributed.all_gather(output, tensor, group=group, async_op=False) + + return torch.cat(output, dim=0) + + def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = None): if group is None: return hf_gather_object(object) From 4569e5445f0d7e95148da21ff35624a8d487d097 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 28 Sep 2025 21:12:11 +0800 Subject: [PATCH 14/83] fix repeat n --- swift/megatron/argument/megatron_args.py | 5 +++ swift/megatron/trainers/dpo_trainer.py | 16 -------- swift/megatron/trainers/grpo_trainer.py | 51 ++++++++++++++---------- swift/megatron/trainers/rlhf_base.py | 4 +- swift/megatron/trainers/utils.py | 22 ++++++---- 5 files changed, 50 insertions(+), 48 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 25cc4b54a6..861ecd8f27 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -195,6 +195,11 @@ def _check_batch_params(): _init_external_vllm() _check_not_supported() _check_batch_params() + # default loss_type if no loss_type is provided + if self.loss_type == 'sigmoid': + self.loss_type = 'grpo' + assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \ + f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}' if self.async_generate or not self.use_vllm: self.sleep_level = 0 self.remove_unused_columns = False diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 72d57e50b4..bc212de2b9 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -84,22 +84,6 @@ def _forward_step_helper(model, inputs): return output_tensor - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - cu_seqlens = 
packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples * 2, )) - for i in range(num_samples * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, packed_seq_params): args = get_args() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 1b6ea01e4d..60946a5e13 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -3,7 +3,7 @@ import inspect from collections import namedtuple from contextlib import contextmanager, nullcontext -from copy import copy +from copy import copy, deepcopy from functools import partial from typing import Any, Dict, List, Union @@ -21,7 +21,7 @@ from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer -from .utils import gather, gather_dict, gather_object, profiling_context +from .utils import gather, gather_dict, gather_object, process_packed_seq_params, profiling_context try: from mbridge import AutoBridge @@ -217,7 +217,7 @@ def forward_step(self, data_iterator, model): data = next(data_iterator) inputs = { k: v - for k, v in inputs.items() if k not in + for k, v in data.items() if k not in ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask'] } @@ -231,7 +231,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): completion_mask = data['completion_mask'] packed_seq_params = data['packed_seq_params'] - per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params) + per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, per_token=True) if self.beta != 0.0: ref_per_token_logps = data.get('ref_per_token_logps') per_token_kl = ( @@ -321,7 +321,7 @@ def _generate_and_score_completions(self, batch): # batch : same across DP groups def get_local_rollout_batch(batch): # repeat num_generations times - global_rollout_batch = [item for item in batch for _ in range(self.num_generations)] + global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] # get local rollout data rollout_rank = torch.distributed.get_rank(group=self.rollout_group) rollout_group_size = torch.distributed.get_world_size(group=self.rollout_group) @@ -331,8 +331,7 @@ def get_local_rollout_batch(batch): rollout_batch = global_rollout_batch[data_slice] return rollout_batch - # Step1: get local rollout data in DP group - # rollout_batch : repeat num_generations times, get current process rollout data + # Step1: Rollout / Reward / Advantage rollout_batch = get_local_rollout_batch(batch) @@ -347,7 +346,8 @@ def _get_encoded_batch(rollout_batch, advantages): encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] encoded_batch = to_device(template.data_collator(encoded_batch), self.device) labels = encoded_batch['labels'] - logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + # TODO: logits_to_keep + # logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() if self.template.padding_free: position_ids 
= encoded_batch.get('text_position_ids') if position_ids is None: @@ -362,7 +362,7 @@ def _get_encoded_batch(rollout_batch, advantages): encoded_batch.update({ 'completion_mask': - labels[:, -logits_to_keep:] != -100, + labels != -100, 'truncated_mask': torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device), 'advantages': @@ -373,8 +373,7 @@ def _get_encoded_batch(rollout_batch, advantages): return encoded_batch - # Step2: gather in DP group, model forward to get ref/old logps - # prepare model forward kwargs + # Step2: ref/old logps total_batch = gather_object(rollout_batch, group=self.rollout_group) total_advantages = gather(advantages, group=self.rollout_group) mini_batch_data = [] @@ -419,8 +418,8 @@ def _generate_completions(self, batch): self.engine.engine.wake_up(tags=['kv_cache']) batch = self.preprocess_rollout_data(batch) - output: List[RolloutOutput] = self._rollout(batch) - batch = self.postprocess_rollout_data(batch, output) + outputs: List[RolloutOutput] = self._rollout(batch) + batch = self.postprocess_rollout_data(batch, outputs) return batch def preprocess_rollout_data(self, batch): @@ -451,7 +450,7 @@ def _rollout(self, batch) -> List[RolloutOutput]: rollout_outputs = self._colocate_rollout(batch, request_config) return rollout_outputs - def postprocess_rollout_data(self, batch, output): + def postprocess_rollout_data(self, batch, outputs): """ Post-process the raw vLLM generation outputs and merge them back into the original input batch. @@ -459,7 +458,7 @@ def postprocess_rollout_data(self, batch, output): Args: batch (List[Dict[str, Any]]): Original rollout samples. - output (List[RolloutOutput]): + outputs (List[RolloutOutput]): outputs from vLLM from vLLM TP group Returns: @@ -469,9 +468,9 @@ def postprocess_rollout_data(self, batch, output): if self.vllm_tensor_parallel_size > 1: local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - orig_size = len(output) // self.vllm_tensor_parallel_size + orig_size = len(outputs) // self.vllm_tensor_parallel_size tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - output = output[tp_slice] + outputs = outputs[tp_slice] def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput): response = output.response @@ -485,7 +484,6 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out messages = input_data['messages'] remove_response(messages) messages.append({'role': 'assistant', 'content': choice.message.content}) - # Step 2: Add token IDs and loss mask if output.response_token_ids: input_data['response_token_ids'] = output.response_token_ids @@ -505,8 +503,8 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out return input_data - assert len(batch) == len(output) - return [merge_output_input_data(input_data, output) for input_data, output in zip(batch, output)] + assert len(batch) == len(outputs) + return [merge_output_input_data(input_data, output) for input_data, output in zip(batch, outputs)] def _get_request_config(self) -> RequestConfig: request_config = copy(self.request_config) @@ -526,6 +524,10 @@ def _get_request_config(self) -> RequestConfig: def _colocate_rollout(self, batch, request_config: RequestConfig): outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) + completions = [output.response.choices[0].message.content for output in outputs] + if self.process_index == 0: + for 
completion in completions: + print(completion) return outputs def _score_completions(self, inputs: DataType) -> torch.Tensor: @@ -605,11 +607,16 @@ def _maybe_compute_logps(self, batch: DataType) -> DataType: # TODO: entropy if self.beta != 0.0: with torch.no_grad(), self.null_ref_context() as ref_model: - batch['ref_per_token_logps'] = self.model_forward(ref_model, iter([batch]), no_grad=True)['logps'] + batch['ref_per_token_logps'] = self.model_forward( + ref_model, iter([batch]), no_grad=True, per_token=True)['logps'] if not self.on_policy: batch['old_per_token_logps'] = self.model_forward( - self.unwrapped_model, iter([batch]), no_grad=True)['logps'] + self.unwrapped_model, iter([batch]), no_grad=True, per_token=True)['logps'] + + # get packed_seq_params, from get_batch func + batch = process_packed_seq_params(batch) + return batch @contextmanager diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index c9a9934a49..2a7064081d 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -96,7 +96,7 @@ def _forward_step_helper(model, inputs): return output_tensor - def model_forward(self, model, data_iterator, no_grad=True): + def model_forward(self, model, data_iterator, no_grad=True, per_token=False): # used to calculate model forward (logps) with self.stimer(bdata=True): data = get_batch(data_iterator) @@ -106,7 +106,7 @@ def model_forward(self, model, data_iterator, no_grad=True): with context: output_tensor = self._forward_step_helper(model, data) data['logps'] = None if labels is None else self.get_logps( - output_tensor, labels, data['packed_seq_params'], per_token=True) + output_tensor, labels, data['packed_seq_params'], per_token=per_token) return data @staticmethod diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 7ffc33ff27..26655c8ce6 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -68,6 +68,18 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams: qkv_format='thd') +def process_packed_seq_params(batch: Dict[str, Any]) -> int: + args = get_args() + num_samples = batch.pop('num_samples') + text_position_ids = batch.pop('text_position_ids', None) + if text_position_ids is None: + text_position_ids = batch.get('position_ids') + if args.padding_free and text_position_ids is not None: + batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) + batch['packed_seq_params'].num_samples = num_samples + return batch + + def _split_tokens(tokens, cu_seqlens): assert tokens.shape[-2] == 1, f'tokens.shape: {tokens.shape}' # [..., 1, L] new_tokens = [] @@ -146,14 +158,8 @@ def get_batch(data_iterator): """Generate a batch.""" # get batches based on the TP rank you are on batch = get_batch_on_this_tp_rank(data_iterator) - args = get_args() - num_samples = batch.pop('num_samples') - text_position_ids = batch.pop('text_position_ids', None) - if text_position_ids is None: - text_position_ids = batch.get('position_ids') - if args.padding_free and text_position_ids is not None: - batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) - batch['packed_seq_params'].num_samples = num_samples + # process batch for packed sequence support + batch = process_packed_seq_params(batch) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch From 9cb84e30a1c7915ea281bdd5f7a37b82d82edb15 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 29 Sep 2025 
17:22:27 +0800 Subject: [PATCH 15/83] fix padding to multiple of tp_size --- docs/source/_extra/.htaccess | 55 ------ docs/source/_extra/url_aliases.py | 149 --------------- docs/source_en/_extra/.htaccess | 55 ------ swift/megatron/argument/megatron_args.py | 6 +- swift/megatron/model/gpt_model.py | 74 -------- swift/megatron/trainers/grpo_trainer.py | 228 ++++++++++++++--------- swift/megatron/trainers/utils.py | 34 +--- 7 files changed, 139 insertions(+), 462 deletions(-) delete mode 100644 docs/source/_extra/.htaccess delete mode 100644 docs/source/_extra/url_aliases.py delete mode 100644 docs/source_en/_extra/.htaccess diff --git a/docs/source/_extra/.htaccess b/docs/source/_extra/.htaccess deleted file mode 100644 index bc2ea7a102..0000000000 --- a/docs/source/_extra/.htaccess +++ /dev/null @@ -1,55 +0,0 @@ -# URL重写规则 - 用于缩短readthedocs链接 -# 这个文件将被复制到readthedocs的根目录,用于URL重定向 - -RewriteEngine On - -# 常见问题整理 - 缩短URL -RewriteRule ^faq/?$ /zh-cn/latest/Instruction/常见问题整理.html [R=301,L] -RewriteRule ^faq/(.*)$ /zh-cn/latest/Instruction/常见问题整理.html#$1 [R=301,L] - -# 支持的模型和数据集 -RewriteRule ^models/?$ /zh-cn/latest/Instruction/支持的模型和数据集.html [R=301,L] -RewriteRule ^models/(.*)$ /zh-cn/latest/Instruction/支持的模型和数据集.html#$1 [R=301,L] - -# 命令行参数 -RewriteRule ^params/?$ /zh-cn/latest/Instruction/命令行参数.html [R=301,L] -RewriteRule ^params/(.*)$ /zh-cn/latest/Instruction/命令行参数.html#$1 [R=301,L] - -# 自定义数据集 -RewriteRule ^custom-dataset/?$ /zh-cn/latest/Customization/自定义数据集.html [R=301,L] -RewriteRule ^custom-dataset/(.*)$ /zh-cn/latest/Customization/自定义数据集.html#$1 [R=301,L] - -# 推理和部署 -RewriteRule ^deploy/?$ /zh-cn/latest/Instruction/推理和部署.html [R=301,L] -RewriteRule ^deploy/(.*)$ /zh-cn/latest/Instruction/推理和部署.html#$1 [R=301,L] - -# 评测 -RewriteRule ^eval/?$ /zh-cn/latest/Instruction/评测.html [R=301,L] -RewriteRule ^eval/(.*)$ /zh-cn/latest/Instruction/评测.html#$1 [R=301,L] - -# 预训练与微调 -RewriteRule ^training/?$ /zh-cn/latest/Instruction/预训练与微调.html [R=301,L] -RewriteRule ^training/(.*)$ /zh-cn/latest/Instruction/预训练与微调.html#$1 [R=301,L] - -# SWIFT安装 -RewriteRule ^install/?$ /zh-cn/latest/GetStarted/SWIFT安装.html [R=301,L] -RewriteRule ^install/(.*)$ /zh-cn/latest/GetStarted/SWIFT安装.html#$1 [R=301,L] - -# 快速开始 -RewriteRule ^quickstart/?$ /zh-cn/latest/GetStarted/快速开始.html [R=301,L] -RewriteRule ^quickstart/(.*)$ /zh-cn/latest/GetStarted/快速开始.html#$1 [R=301,L] - -# 多模态 -RewriteRule ^multimodal/?$ /zh-cn/latest/Multi-Modal/index.html [R=301,L] -RewriteRule ^multimodal/(.*)$ /zh-cn/latest/Multi-Modal/$1 [R=301,L] - -# 强化学习 -RewriteRule ^rl/?$ /zh-cn/latest/RLHF/index.html [R=301,L] -RewriteRule ^rl/(.*)$ /zh-cn/latest/RLHF/$1 [R=301,L] - -# 自定义 -RewriteRule ^custom/?$ /zh-cn/latest/Customization/index.html [R=301,L] -RewriteRule ^custom/(.*)$ /zh-cn/latest/Customization/$1 [R=301,L] - -# 如果以上规则都不匹配,则重定向到主页 -RewriteRule ^.*$ /zh-cn/latest/ [R=301,L] diff --git a/docs/source/_extra/url_aliases.py b/docs/source/_extra/url_aliases.py deleted file mode 100644 index 75faf906c2..0000000000 --- a/docs/source/_extra/url_aliases.py +++ /dev/null @@ -1,149 +0,0 @@ -# URL别名映射 - 用于在文档中生成短链接 -# 这个文件定义了常用的URL别名,可以在文档构建时使用 - -# 中文版URL别名 -ZH_URL_ALIASES = { - # 常见问题 - 'faq': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/常见问题整理.html', - '常见问题': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/常见问题整理.html', - - # 支持的模型和数据集 - 'models': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/支持的模型和数据集.html', - '模型列表': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/支持的模型和数据集.html', - '数据集列表': 
'https://swift.readthedocs.io/zh-cn/latest/Instruction/支持的模型和数据集.html', - - # 命令行参数 - 'params': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/命令行参数.html', - '命令行参数': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/命令行参数.html', - '参数说明': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/命令行参数.html', - - # 自定义数据集 - 'custom-dataset': 'https://swift.readthedocs.io/zh-cn/latest/Customization/自定义数据集.html', - '自定义数据集': 'https://swift.readthedocs.io/zh-cn/latest/Customization/自定义数据集.html', - - # 推理和部署 - 'deploy': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/推理和部署.html', - '推理部署': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/推理和部署.html', - '部署': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/推理和部署.html', - - # 评测 - 'eval': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/评测.html', - '评测': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/评测.html', - - # 预训练与微调 - 'training': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/预训练与微调.html', - '训练': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/预训练与微调.html', - '微调': 'https://swift.readthedocs.io/zh-cn/latest/Instruction/预训练与微调.html', - - # SWIFT安装 - 'install': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/SWIFT安装.html', - '安装': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/SWIFT安装.html', - - # 快速开始 - 'quickstart': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/快速开始.html', - '快速开始': 'https://swift.readthedocs.io/zh-cn/latest/GetStarted/快速开始.html', - - # 多模态 - 'multimodal': 'https://swift.readthedocs.io/zh-cn/latest/Multi-Modal/index.html', - '多模态': 'https://swift.readthedocs.io/zh-cn/latest/Multi-Modal/index.html', - - # 强化学习 - 'rl': 'https://swift.readthedocs.io/zh-cn/latest/RLHF/index.html', - '强化学习': 'https://swift.readthedocs.io/zh-cn/latest/RLHF/index.html', - 'RLHF': 'https://swift.readthedocs.io/zh-cn/latest/RLHF/index.html', - - # 自定义 - 'custom': 'https://swift.readthedocs.io/zh-cn/latest/Customization/index.html', - '自定义': 'https://swift.readthedocs.io/zh-cn/latest/Customization/index.html', -} - -# 英文版URL别名 -EN_URL_ALIASES = { - # Frequently Asked Questions - 'faq': 'https://swift.readthedocs.io/en/latest/Instruction/Frequently-asked-questions.html', - 'frequently-asked-questions': 'https://swift.readthedocs.io/en/latest/Instruction/Frequently-asked-questions.html', - - # Supported Models and Datasets - 'models': 'https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html', - 'supported-models': 'https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html', - 'datasets': 'https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html', - - # Command Line Parameters - 'params': 'https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html', - 'command-line-parameters': 'https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html', - 'parameters': 'https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html', - - # Custom Dataset - 'custom-dataset': 'https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html', - 'custom-datasets': 'https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html', - - # Inference and Deployment - 'deploy': 'https://swift.readthedocs.io/en/latest/Instruction/Inference-and-deployment.html', - 'inference': 'https://swift.readthedocs.io/en/latest/Instruction/Inference-and-deployment.html', - 'deployment': 
'https://swift.readthedocs.io/en/latest/Instruction/Inference-and-deployment.html', - - # Evaluation - 'eval': 'https://swift.readthedocs.io/en/latest/Instruction/Evaluation.html', - 'evaluation': 'https://swift.readthedocs.io/en/latest/Instruction/Evaluation.html', - - # Pre-training and Fine-tuning - 'training': 'https://swift.readthedocs.io/en/latest/Instruction/Pre-training-and-fine-tuning.html', - 'pre-training': 'https://swift.readthedocs.io/en/latest/Instruction/Pre-training-and-fine-tuning.html', - 'fine-tuning': 'https://swift.readthedocs.io/en/latest/Instruction/Pre-training-and-fine-tuning.html', - - # SWIFT Installation - 'install': 'https://swift.readthedocs.io/en/latest/GetStarted/SWIFT-installation.html', - 'installation': 'https://swift.readthedocs.io/en/latest/GetStarted/SWIFT-installation.html', - - # Quick Start - 'quickstart': 'https://swift.readthedocs.io/en/latest/GetStarted/Quick-start.html', - 'quick-start': 'https://swift.readthedocs.io/en/latest/GetStarted/Quick-start.html', - - # Multi-Modal - 'multimodal': 'https://swift.readthedocs.io/en/latest/Multi-Modal/index.html', - 'multi-modal': 'https://swift.readthedocs.io/en/latest/Multi-Modal/index.html', - - # Reinforcement Learning - 'rl': 'https://swift.readthedocs.io/en/latest/RLHF/index.html', - 'rlhf': 'https://swift.readthedocs.io/en/latest/RLHF/index.html', - 'reinforcement-learning': 'https://swift.readthedocs.io/en/latest/RLHF/index.html', - - # Customization - 'custom': 'https://swift.readthedocs.io/en/latest/Customization/index.html', - 'customization': 'https://swift.readthedocs.io/en/latest/Customization/index.html', -} - - -def get_url_alias(alias_key, language='zh'): - """ - 获取URL别名对应的完整URL - - Args: - alias_key (str): 别名键 - language (str): 语言,'zh' 或 'en' - - Returns: - str: 完整的URL,如果找不到别名则返回None - """ - if language == 'zh': - return ZH_URL_ALIASES.get(alias_key) - elif language == 'en': - return EN_URL_ALIASES.get(alias_key) - return None - - -def get_all_aliases(language='zh'): - """ - 获取所有URL别名 - - Args: - language (str): 语言,'zh' 或 'en' - - Returns: - dict: 所有别名的字典 - """ - if language == 'zh': - return ZH_URL_ALIASES.copy() - elif language == 'en': - return EN_URL_ALIASES.copy() - return {} diff --git a/docs/source_en/_extra/.htaccess b/docs/source_en/_extra/.htaccess deleted file mode 100644 index fc04d9a15d..0000000000 --- a/docs/source_en/_extra/.htaccess +++ /dev/null @@ -1,55 +0,0 @@ -# URL Rewrite Rules - For shortening readthedocs links -# This file will be copied to the readthedocs root directory for URL redirection - -RewriteEngine On - -# Frequently Asked Questions - Shorten URL -RewriteRule ^faq/?$ /en/latest/Instruction/Frequently-asked-questions.html [R=301,L] -RewriteRule ^faq/(.*)$ /en/latest/Instruction/Frequently-asked-questions.html#$1 [R=301,L] - -# Supported Models and Datasets -RewriteRule ^models/?$ /en/latest/Instruction/Supported-models-and-datasets.html [R=301,L] -RewriteRule ^models/(.*)$ /en/latest/Instruction/Supported-models-and-datasets.html#$1 [R=301,L] - -# Command Line Parameters -RewriteRule ^params/?$ /en/latest/Instruction/Command-line-parameters.html [R=301,L] -RewriteRule ^params/(.*)$ /en/latest/Instruction/Command-line-parameters.html#$1 [R=301,L] - -# Custom Dataset -RewriteRule ^custom-dataset/?$ /en/latest/Customization/Custom-dataset.html [R=301,L] -RewriteRule ^custom-dataset/(.*)$ /en/latest/Customization/Custom-dataset.html#$1 [R=301,L] - -# Inference and Deployment -RewriteRule ^deploy/?$ /en/latest/Instruction/Inference-and-deployment.html [R=301,L] 
-RewriteRule ^deploy/(.*)$ /en/latest/Instruction/Inference-and-deployment.html#$1 [R=301,L] - -# Evaluation -RewriteRule ^eval/?$ /en/latest/Instruction/Evaluation.html [R=301,L] -RewriteRule ^eval/(.*)$ /en/latest/Instruction/Evaluation.html#$1 [R=301,L] - -# Pre-training and Fine-tuning -RewriteRule ^training/?$ /en/latest/Instruction/Pre-training-and-fine-tuning.html [R=301,L] -RewriteRule ^training/(.*)$ /en/latest/Instruction/Pre-training-and-fine-tuning.html#$1 [R=301,L] - -# SWIFT Installation -RewriteRule ^install/?$ /en/latest/GetStarted/SWIFT-installation.html [R=301,L] -RewriteRule ^install/(.*)$ /en/latest/GetStarted/SWIFT-installation.html#$1 [R=301,L] - -# Quick Start -RewriteRule ^quickstart/?$ /en/latest/GetStarted/Quick-start.html [R=301,L] -RewriteRule ^quickstart/(.*)$ /en/latest/GetStarted/Quick-start.html#$1 [R=301,L] - -# Multi-Modal -RewriteRule ^multimodal/?$ /en/latest/Multi-Modal/index.html [R=301,L] -RewriteRule ^multimodal/(.*)$ /en/latest/Multi-Modal/$1 [R=301,L] - -# Reinforcement Learning -RewriteRule ^rl/?$ /en/latest/RLHF/index.html [R=301,L] -RewriteRule ^rl/(.*)$ /en/latest/RLHF/$1 [R=301,L] - -# Customization -RewriteRule ^custom/?$ /en/latest/Customization/index.html [R=301,L] -RewriteRule ^custom/(.*)$ /en/latest/Customization/$1 [R=301,L] - -# If none of the above rules match, redirect to homepage -RewriteRule ^.*$ /en/latest/ [R=301,L] diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index c6ff497bd6..c44a8e7a19 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -169,9 +169,9 @@ def _check_not_supported(): raise ValueError('num_iterations > 1 is not supported for Megatron-GRPO yet, please set it to 1.') def _check_batch_params(): - assert self.micro_batch_size % self.num_generations == 0, \ - f'micro_batch_size ({self.micro_batch_size}) must be divisible' \ - f' by the number of generations ({self.num_generations})' + # assert self.micro_batch_size % self.num_generations == 0, \ + # f'micro_batch_size ({self.micro_batch_size}) must be divisible' \ + # f' by the number of generations ({self.num_generations})' if self.generation_batch_size is None and self.steps_per_generation is None: self.steps_per_generation = 1 self.generation_batch_size = self.global_batch_size * self.steps_per_generation diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 72f30f9f44..a0bb85955b 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -123,80 +123,6 @@ def __init__( logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') - # Set tensor_model_parallel attributes for all parameters - # This is needed for mbridge to correctly identify TP parameters - # self._set_tensor_model_parallel_attributes() - - def _set_tensor_model_parallel_attributes(self): - """Set tensor_model_parallel attributes for all parameters. - - This method ensures that all parameters have the correct tensor_model_parallel - attributes set, which is required for mbridge to correctly identify TP parameters - during weight export. 
- """ - from megatron.core.tensor_parallel.layers import set_tensor_model_parallel_attributes - - # Get tensor parallel size - from megatron.core import parallel_state - tp_size = parallel_state.get_tensor_model_parallel_world_size() - - if tp_size <= 1: - return # No tensor parallelism, no need to set attributes - - # Set attributes for all parameters - for name, param in self.named_parameters(): - if not hasattr(param, 'tensor_model_parallel'): - # Determine if this parameter should be tensor parallel - is_tp_param = self._is_tensor_parallel_parameter(name, param) - if is_tp_param: - # Determine partition dimension based on parameter name - partition_dim = self._get_partition_dimension(name, param) - set_tensor_model_parallel_attributes(param, True, partition_dim, 1) - else: - # Set default attributes for non-TP parameters - setattr(param, 'tensor_model_parallel', False) - setattr(param, 'partition_dim', -1) - setattr(param, 'partition_stride', 1) - - def _is_tensor_parallel_parameter(self, name: str, param) -> bool: - """Determine if a parameter should be tensor parallel based on its name and shape.""" - # Parameters that are typically tensor parallel - tp_patterns = [ - 'weight', # Linear layer weights - 'qkv_proj.weight', # QKV projection weights - 'dense.weight', # Dense layer weights - 'fc1.weight', # MLP first layer weights - 'fc2.weight', # MLP second layer weights - 'gate_proj.weight', # Gate projection weights - 'up_proj.weight', # Up projection weights - 'down_proj.weight', # Down projection weights - ] - - # Check if parameter name matches any TP pattern - for pattern in tp_patterns: - if pattern in name: - return True - - # Special cases for bias parameters in TP layers - if 'bias' in name and any(pattern in name for pattern in tp_patterns): - return True - - return False - - def _get_partition_dimension(self, name: str, param) -> int: - """Get the partition dimension for a tensor parallel parameter.""" - # Column parallel layers (partition along output dimension) - if any(pattern in name - for pattern in ['qkv_proj.weight', 'gate_proj.weight', 'up_proj.weight', 'fc1.weight', 'dense.weight']): - return 0 # Partition along output dimension - - # Row parallel layers (partition along input dimension) - if any(pattern in name for pattern in ['down_proj.weight', 'fc2.weight']): - return 1 # Partition along input dimension - - # Default to partition along output dimension - return 0 - @contextmanager def _patch_apply_rotary_pos_emb(self): if self.attention_scaling == 1.: diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 60946a5e13..81825ea736 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1,7 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import gc import inspect -from collections import namedtuple from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from functools import partial @@ -21,7 +20,7 @@ from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer -from .utils import gather, gather_dict, gather_object, process_packed_seq_params, profiling_context +from .utils import gather, gather_object, process_packed_seq_params, profiling_context try: from mbridge import AutoBridge @@ -39,12 +38,28 @@ def __init__(self, args: MegatronRLHFArguments, template: Template): self.hf_model_dir = args.model_info.model_dir self.processing_class = self.template.processor # TODO: multi turn scheduler(colocate multi turn) + self._prepare_template_data_collator() self._init_grpo_params() self._prepare_rewards() self._prepare_rollout_engine() # debug: use mbridge to convert mcore to hf self.bridge = None + def _prepare_template_data_collator(self): + template = self.template + args = self.args + data_collator = template.data_collator + padding_to = None + if args.tensor_model_parallel_size > 1 and args.sequence_parallel: + padding_to = args.tensor_model_parallel_size + if args.context_parallel_size > 1: + padding_to = (padding_to or 1) * args.context_parallel_size + if args.fp8_format: + padding_to = max((padding_to or 1) * 8, 16) + logger.info(f'padding_to: {padding_to}') + data_collator = partial(data_collator, padding_to=padding_to) + template.data_collator = data_collator + def _init_grpo_params(self): args = self.args # distributed params @@ -210,98 +225,26 @@ def _prepare_rewards(self): assert self.reward_funcs, 'reward_funcs is not set' - def forward_step(self, data_iterator, model): - # train_batch_size - # return: output_tensor, loss_func - - data = next(data_iterator) - inputs = { - k: v - for k, v in data.items() if k not in - ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask'] - } - - with self.stimer: - output_tensor = model(**inputs) - return output_tensor, partial(self.loss_func, data=data) - - def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): - advantages = data['advantages'] - labels = data['labels'] - completion_mask = data['completion_mask'] - packed_seq_params = data['packed_seq_params'] - - per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, per_token=True) - if self.beta != 0.0: - ref_per_token_logps = data.get('ref_per_token_logps') - per_token_kl = ( - torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) - - old_per_token_logps = ( - per_token_logps.detach() if data.get('old_per_token_logps') is None else data['old_per_token_logps']) - log_ratio = per_token_logps - old_per_token_logps - - if self.importance_sampling_level == 'token': - log_importance_weights = log_ratio - elif self.importance_sampling_level == 'sequence': - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - log_importance_weights = log_importance_weights.unsqueeze(-1) - elif self.importance_sampling_level == 'sequence_token': - # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] - seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient - log_importance_weights = per_token_logps - 
per_token_logps.detach() + seq_level_log_weight - else: - raise ValueError( - f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " - "and 'sequence'.") - - coef_1 = torch.exp(log_importance_weights) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - if self.args.delta is not None: - coef_1 = torch.clamp(coef_1, max=self.args.delta) - - if self.template.padding_free: - advantages = advantages[-coef_1.shape[1]:] - per_token_loss1 = coef_1 * advantages.unsqueeze(0) - per_token_loss2 = coef_2 * advantages.unsqueeze(0) - else: - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) - per_token_loss = -torch.min(per_token_loss1, per_token_loss2) - if self.beta != 0.0: - per_token_loss = per_token_loss + self.beta * per_token_kl - - if self.loss_type == 'grpo': - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() - elif self.loss_type == 'bnpo': - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) - elif self.loss_type == 'dr_grpo': - loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - else: - raise ValueError(f'Unknown loss type: {self.loss_type}') - - loss = loss.mean() - metric = { - 'loss': loss.clone().detach(), - } - reporting_metric = loss.new_tensor(list(metric.values())) - torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) - reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} - # fix megatron-lm bug - # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 - loss = loss / mpu.get_context_parallel_world_size() - return loss, reporting_metric - def _replace_data_iterator(self, data_iterator): args = get_args() if args.iteration % self.steps_per_generation == 0: # each rollout DP group will generate generation_batch_size / world_size completions - num_iters_per_step = self.generation_batch_size // mpu.get_data_parallel_world_size() + completions_to_rollout = self.generation_batch_size // mpu.get_data_parallel_world_size() # completions will be repeated num_generations times after - num_iters_per_step = num_iters_per_step // self.num_generations + # so we need to divide num_iters_per_step by num_generations to get prompt batch size + prompts_to_rollout = completions_to_rollout // self.num_generations + # every iter will generate micro_batch_size prompts + num_iters_per_step = prompts_to_rollout // self.micro_batch_size + assert num_iters_per_step > 0, ( + f'num_iters_per_step={num_iters_per_step} <= 0. ' + f'This means no prompts will be generated' + f'generation_batch_size={self.generation_batch_size}, ' + f'data_parallel_world_size={mpu.get_data_parallel_world_size()}, ' + f'num_generations={self.num_generations}, ' + f'micro_batch_size={self.micro_batch_size}. 
' + 'Please adjust these parameters so that ' + 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.') rollout_batch = [] for _ in range(self.steps_per_generation): for _ in range(num_iters_per_step): @@ -354,12 +297,26 @@ def _get_encoded_batch(rollout_batch, advantages): position_ids = encoded_batch.get('position_ids') squeezed_position_ids = position_ids.squeeze() assert squeezed_position_ids is not None - + # Remove trailing padding zeros from position_ids to avoid interference + # Find the last non-zero position + last_nonzero_idx = (squeezed_position_ids != 0).nonzero(as_tuple=True)[0] + if len(last_nonzero_idx) > 0: + # Keep only up to the last non-zero position + 1 to include the last valid position + squeezed_position_ids = squeezed_position_ids[:last_nonzero_idx[-1] + 1] + + # Calculate lengths based on sequence boundaries (position_ids == 0) lengths = torch.diff( torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) advantages = torch.repeat_interleave(advantages, lengths) + # Pad advantages to match the original position_ids length + original_length = position_ids.shape[1] + if advantages.shape[0] < original_length: + padding_length = original_length - advantages.shape[0] + padding = torch.zeros(padding_length, device=advantages.device, dtype=advantages.dtype) + advantages = torch.cat([advantages, padding]) + encoded_batch.update({ 'completion_mask': labels != -100, @@ -603,16 +560,18 @@ def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tens return advantages - def _maybe_compute_logps(self, batch: DataType) -> DataType: + def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: # TODO: entropy + inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']} + if self.beta != 0.0: with torch.no_grad(), self.null_ref_context() as ref_model: batch['ref_per_token_logps'] = self.model_forward( - ref_model, iter([batch]), no_grad=True, per_token=True)['logps'] + ref_model, iter([inputs]), no_grad=True, per_token=True)['logps'] if not self.on_policy: batch['old_per_token_logps'] = self.model_forward( - self.unwrapped_model, iter([batch]), no_grad=True, per_token=True)['logps'] + self.unwrapped_model, iter([inputs]), no_grad=True, per_token=True)['logps'] # get packed_seq_params, from get_batch func batch = process_packed_seq_params(batch) @@ -657,7 +616,7 @@ def patch_megatron_data_collator(self, data_collator): def build_pretraining_data_loader(*_args, **kwargs): args = get_args() org_micro_batch_size = args.micro_batch_size - args.micro_batch_size = org_micro_batch_size // self.num_generations + # args.micro_batch_size = org_micro_batch_size // self.num_generations res = origin_build_pretraining_data_loader(*_args, **kwargs) args.micro_batch_size = org_micro_batch_size if res is not None and args.dataloader_type != 'external': @@ -669,3 +628,86 @@ def build_pretraining_data_loader(*_args, **kwargs): yield finally: training.build_pretraining_data_loader = origin_build_pretraining_data_loader + + def forward_step(self, data_iterator, model): + # train_batch_size + # return: output_tensor, loss_func + + data = next(data_iterator) + inputs = { + k: v + for k, v in data.items() if k not in + ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask'] + } + + with self.stimer: + output_tensor = model(**inputs) + return output_tensor, 
partial(self.loss_func, data=data) + + def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): + advantages = data['advantages'] + labels = data['labels'] + completion_mask = data['completion_mask'] + packed_seq_params = data['packed_seq_params'] + per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, per_token=True) + if self.beta != 0.0: + ref_per_token_logps = data.get('ref_per_token_logps') + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + old_per_token_logps = ( + per_token_logps.detach() if data.get('old_per_token_logps') is None else data['old_per_token_logps']) + log_ratio = per_token_logps - old_per_token_logps + + if self.importance_sampling_level == 'token': + log_importance_weights = log_ratio + elif self.importance_sampling_level == 'sequence': + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + elif self.importance_sampling_level == 'sequence_token': + # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] + seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'.") + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + if self.template.padding_free: + advantages = advantages[-coef_1.shape[1]:] + per_token_loss1 = coef_1 * advantages.unsqueeze(0) + per_token_loss2 = coef_2 * advantages.unsqueeze(0) + else: + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') + + loss = loss.mean() + metric = { + 'loss': loss.clone().detach(), + } + reporting_metric = loss.new_tensor(list(metric.values())) + torch.distributed.all_reduce( + reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} + # fix megatron-lm bug + # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 + loss = loss / mpu.get_context_parallel_world_size() + return loss, reporting_metric diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index b8f81bd58f..33b945cb2c 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -159,31 +159,6 @@ def profiling_context(trainer, name: str): # TODO: add swanlab 
support -def gather_dict(tensors: Dict[str, torch.Tensor], group: torch.distributed.ProcessGroup): - if not isinstance(tensors, dict): - raise ValueError(f'Expected a dictionary, got {type(tensors)}') - size = torch.distributed.get_world_size(group=group) - - output = {} - sorted_keys = sorted(tensors.keys()) - for key in sorted_keys: - val = tensors[key] - if isinstance(val, int): - # num_samples - output[key] = val - continue - elif isinstance(val, torch.Tensor): - output[key] = [torch.empty_like(val) for _ in range(size)] - torch.distributed.all_gather(output[key], val, group=group, async_op=False) - output[key] = torch.cat(output[key], dim=0) - else: - output[key] = [None for _ in range(size)] - torch.distributed.all_gather_object(output[key], val, group=group, async_op=False) - output[key] = [item for sublist in output[key] for item in sublist] - - return output - - def gather(tensor, group: Optional[torch.distributed.ProcessGroup] = None): if group is None: return hf_gather(tensor) @@ -200,12 +175,5 @@ def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = size = torch.distributed.get_world_size(group=group) output_objects = [None for _ in range(size)] torch.distributed.all_gather_object(output_objects, object) - # all_gather_object returns a list of lists, so we need to flatten it + # flatten return [x for y in output_objects for x in y] - - -def make_batch_generator(batch: List[Dict[str, Any]], batch_size: int): - assert batch_size > 0, 'batch_size must be positive' - assert len(batch) % batch_size == 0, 'batch length must be a multiple of batch_size' - for i in range(0, len(batch), batch_size): - yield batch[i:i + batch_size] From 8627aa33d568162dbd43fc918f641ec4ca45f6aa Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 29 Sep 2025 19:23:56 +0800 Subject: [PATCH 16/83] compute loss --- swift/megatron/trainers/grpo_trainer.py | 39 ++++++++++++++++++++----- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 81825ea736..bbc7bd8910 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -649,7 +649,24 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): labels = data['labels'] completion_mask = data['completion_mask'] packed_seq_params = data['packed_seq_params'] + truncated_mask = data['truncated_mask'] + micro_batch_size = self.micro_batch_size + lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size + + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size] + lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, per_token=True) + + if self.args.overlong_filter and any(truncated_mask): + # TODO: non-padding-free + truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) + padding_length = completion_mask.shape[1] - truncated_mask.shape[1] + if padding_length > 0: + padding = torch.zeros(padding_length, device=truncated_mask.device, dtype=truncated_mask.dtype) + truncated_mask = torch.cat([truncated_mask, padding]) + completion_mask = completion_mask & (~truncated_mask) + else: + raise NotImplementedError # TODO + if self.beta != 0.0: ref_per_token_logps = data.get('ref_per_token_logps') per_token_kl = ( @@ -662,13 +679,16 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): if self.importance_sampling_level == 'token': 
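# --------------------------------------------------------------------------
# Illustrative sketch (outside the diff): the branch here selects between
# token-level and sequence-level (GSPO-style) importance weights. In the
# padding-free/packed layout, a per-sequence weight is the masked mean of the
# per-token log-ratio over each packed sample. The tensors below (log_ratio,
# completion_mask, seq_lengths) are made-up assumptions for this sketch, not
# the trainer's own variables.
import torch

log_ratio = torch.tensor([[0.10, -0.05, 0.02, 0.30, 0.05, 0.00]])  # [1, total_tokens]
completion_mask = torch.tensor([[1., 1., 1., 1., 1., 0.]])          # 0 marks padding
seq_lengths = [4, 2]                                                # tokens per packed sample

# token-level: the weight is simply the per-token log-ratio
token_level = log_ratio

# sequence-level: one masked mean per packed sample, via split over the lengths
ratio_chunks = torch.split(log_ratio.squeeze(0), seq_lengths)
mask_chunks = torch.split(completion_mask.squeeze(0), seq_lengths)
sequence_level = torch.stack([
    (lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(ratio_chunks, mask_chunks)
])
print(token_level.shape, sequence_level)  # torch.Size([1, 6]) tensor([0.0925, 0.0500])
# --------------------------------------------------------------------------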
log_importance_weights = log_ratio elif self.importance_sampling_level == 'sequence': - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - log_importance_weights = log_importance_weights.unsqueeze(-1) - elif self.importance_sampling_level == 'sequence_token': - # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] - seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient - log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) + seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)] + seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1) + if self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights + else: + seq_level_log_weight = seq_level_log_weights.detach() + seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0) + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " @@ -691,6 +711,11 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): per_token_loss = per_token_loss + self.beta * per_token_kl if self.loss_type == 'grpo': + loss_list = torch.split(per_token_loss.squeeze(0), lengths_with_padding.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) + sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) for loss, mask in zip(loss_list, mask_list)] + loss = torch.stack(sample_loss[:micro_batch_size]).mean() + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() elif self.loss_type == 'bnpo': loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) From 2292cf84809dfcdcc8f33f7d2068a024279498fb Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 30 Sep 2025 11:35:33 +0800 Subject: [PATCH 17/83] fix logps --- swift/megatron/trainers/grpo_trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index bbc7bd8910..0b96288a7f 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -20,7 +20,7 @@ from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer -from .utils import gather, gather_object, process_packed_seq_params, profiling_context +from .utils import gather, gather_object, get_batch, process_packed_seq_params, profiling_context try: from mbridge import AutoBridge @@ -574,7 +574,7 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: self.unwrapped_model, iter([inputs]), no_grad=True, per_token=True)['logps'] # get packed_seq_params, from get_batch func - batch = process_packed_seq_params(batch) + # batch = process_packed_seq_params(batch) return batch @@ -632,8 +632,8 @@ def build_pretraining_data_loader(*_args, **kwargs): def forward_step(self, data_iterator, 
model): # train_batch_size # return: output_tensor, loss_func - - data = next(data_iterator) + data = get_batch(data_iterator) + data.pop('loss_scale', None) inputs = { k: v for k, v in data.items() if k not in @@ -664,8 +664,6 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): padding = torch.zeros(padding_length, device=truncated_mask.device, dtype=truncated_mask.dtype) truncated_mask = torch.cat([truncated_mask, padding]) completion_mask = completion_mask & (~truncated_mask) - else: - raise NotImplementedError # TODO if self.beta != 0.0: ref_per_token_logps = data.get('ref_per_token_logps') From bbe5f39da71c16514c551cf394e0a971681d52b2 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 1 Oct 2025 01:48:07 +0800 Subject: [PATCH 18/83] logging & patch VL --- swift/megatron/trainers/grpo_trainer.py | 88 +++++++++++++++++++---- swift/megatron/trainers/rlhf_base.py | 4 +- swift/megatron/trainers/utils.py | 92 ++++++++++++++++++++++++- 3 files changed, 164 insertions(+), 20 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 0b96288a7f..03887abd13 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1,15 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import gc import inspect +from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from functools import partial +from types import MethodType from typing import Any, Dict, List, Union import torch import torch.nn as nn from megatron.core import mpu from megatron.training import get_args, training +from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps from swift.llm import RequestConfig, RowPreprocessor, Template, to_device @@ -18,7 +21,7 @@ from swift.trainers.rlhf_trainer.grpo_trainer import DataType from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response -from ..argument import MegatronRLHFArguments +from ..argument import MegatronArguments, MegatronRLHFArguments from .rlhf_base import MegatronRLHFTrainer from .utils import gather, gather_object, get_batch, process_packed_seq_params, profiling_context @@ -44,6 +47,7 @@ def __init__(self, args: MegatronRLHFArguments, template: Template): self._prepare_rollout_engine() # debug: use mbridge to convert mcore to hf self.bridge = None + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} def _prepare_template_data_collator(self): template = self.template @@ -61,7 +65,7 @@ def _prepare_template_data_collator(self): template.data_collator = data_collator def _init_grpo_params(self): - args = self.args + args: MegatronArguments = self.args # distributed params self.world_size = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() @@ -139,6 +143,7 @@ def _init_rollout_group(self): for group_ranks in rollout_groups: if self.process_index in group_ranks: self.rollout_group = torch.distributed.new_group(ranks=group_ranks) + print(f'rank {self.process_index} join rollout group with ranks: {group_ranks}') break def prepare_vllm(self): @@ -175,6 +180,7 @@ def _move_model_to_vllm(self): # TODO: LoRA, server if self.bridge is None: self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) + self._patch_mbridge(self.bridge) per_tensor_params = self.bridge.export_weights([self.unwrapped_model]) 
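# --------------------------------------------------------------------------
# Illustrative sketch (outside the diff): the weight sync at this point streams
# (name, tensor) pairs from the training-side model into the rollout engine
# instead of materializing a full state_dict. The toy exporter/loader below
# only mimics that calling pattern; it is not the mbridge or vLLM API.
import torch
import torch.nn as nn

def export_weights(model: nn.Module):
    """Yield (param_name, tensor) pairs one at a time."""
    for name, param in model.named_parameters():
        yield name, param.detach().cpu()

def load_weights(target: nn.Module, weights):
    """Consume the stream and copy each tensor into a same-shaped target module."""
    params = dict(target.named_parameters())
    for name, tensor in weights:
        with torch.no_grad():
            params[name].copy_(tensor)

trainer_model = nn.Linear(8, 4)
rollout_model = nn.Linear(8, 4)
load_weights(rollout_model, export_weights(trainer_model))
assert torch.allclose(trainer_model.weight, rollout_model.weight)
# --------------------------------------------------------------------------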
self.engine.inner_model.load_weights(per_tensor_params) # TODO: check tensor_model_parallel @@ -225,6 +231,19 @@ def _prepare_rewards(self): assert self.reward_funcs, 'reward_funcs is not set' + def _patch_mbridge(self, bridge): + original_method = bridge._weight_to_hf_format + + def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): + # skip ViT weights + if 'visual' in mcore_weights_name: + if 'visual.visual' in mcore_weights_name: + mcore_weights_name = mcore_weights_name.replace('visual.visual', 'visual') + return [mcore_weights_name], [mcore_weights] + return original_method(mcore_weights_name, mcore_weights) + + bridge._weight_to_hf_format = _weight_to_hf_format_patched + def _replace_data_iterator(self, data_iterator): args = get_args() @@ -246,9 +265,8 @@ def _replace_data_iterator(self, data_iterator): 'Please adjust these parameters so that ' 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.') rollout_batch = [] - for _ in range(self.steps_per_generation): - for _ in range(num_iters_per_step): - rollout_batch.extend(next(data_iterator)) + for _ in range(num_iters_per_step): + rollout_batch.extend(next(data_iterator)) micro_batch_data = self._generate_and_score_completions(rollout_batch) num_mini_batch = self.global_batch_size // (self.micro_batch_size * mpu.get_data_parallel_world_size()) mini_batch_data = [ @@ -324,8 +342,6 @@ def _get_encoded_batch(rollout_batch, advantages): torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device), 'advantages': advantages, - 'position_ids': - position_ids # remove it: non-padding-free }) return encoded_batch @@ -481,10 +497,6 @@ def _get_request_config(self) -> RequestConfig: def _colocate_rollout(self, batch, request_config: RequestConfig): outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) - completions = [output.response.choices[0].message.content for output in outputs] - if self.process_index == 0: - for completion in completions: - print(completion) return outputs def _score_completions(self, inputs: DataType) -> torch.Tensor: @@ -554,6 +566,28 @@ def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tens advantages = rewards - group_rewards_mean advantages = maybe_normalize_advantages(advantages, group_rewards_std) + def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor): + """Log reward statistics for monitoring. 
Only log once per unique request_id.""" + # rewards: [prompt_batch_size, self.num_generations] + # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] + mode = 'train' if self.unwrapped_model.training else 'eval' + group_rewards = rewards.view(-1, self.num_generations) + rewards_mean = group_rewards.mean(-1).mean().item() + rewards_std = group_rewards.std(-1).mean().item() + is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1))) + + self._metrics[mode]['reward'].append(rewards_mean) + self._metrics[mode]['reward_std'].append(rewards_std) + self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) + + # Log per-reward-function statistics using deduplicated rewards_per_func + for i, name in enumerate(self.reward_func_names): + col = rewards_per_func_for_metrics[:, i] + self._metrics[mode][f'rewards/{name}/mean'].append(torch.nanmean(col).item()) + self._metrics[mode][f'rewards/{name}/std'].append(nanstd(col).item()) + + log_rewards_metrics(rewards=grouped_rewards, rewards_per_func_for_metrics=rewards_per_func) + slice_start = self.process_index * len(batch) slice_end = slice_start + len(batch) advantages = advantages[slice_start:slice_end] @@ -723,13 +757,37 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): raise ValueError(f'Unknown loss type: {self.loss_type}') loss = loss.mean() - metric = { + avg_metric = { 'loss': loss.clone().detach(), + 'completions/mean_length': lengths.float().mean(), } - reporting_metric = loss.new_tensor(list(metric.values())) + max_metric = { + 'completions/max_length': lengths.float().max(), + } + min_metric = { + 'completions/min_length': lengths.float().min(), + } + if self.beta != 0.0: + avg_metric['kl'] = per_token_kl.mean().item() + avg_reporting_metric = loss.new_tensor(list(avg_metric.values())) + max_reporting_metric = loss.new_tensor(list(max_metric.values())) + min_reporting_metric = loss.new_tensor(list(min_metric.values())) + torch.distributed.all_reduce( + avg_reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + + torch.distributed.all_reduce( + max_reporting_metric, torch.distributed.ReduceOp.MAX, group=mpu.get_data_parallel_group()) torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) - reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} + min_reporting_metric, torch.distributed.ReduceOp.MIN, group=mpu.get_data_parallel_group()) + avg_reporting_metric = {k: avg_reporting_metric[i] for i, k in enumerate(avg_metric.keys())} + max_reporting_metric = {k: max_reporting_metric[i] for i, k in enumerate(max_metric.keys())} + min_reporting_metric = {k: min_reporting_metric[i] for i, k in enumerate(min_metric.keys())} + addition_metrics = { + key: torch.tensor(sum(val) / len(val), device=loss.device) + for key, val in self._metrics['train'].items() + } + + reporting_metric = {**avg_reporting_metric, **max_reporting_metric, **min_reporting_metric, **addition_metrics} # fix megatron-lm bug # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 loss = loss / mpu.get_context_parallel_world_size() diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py index 2a7064081d..569854274e 100644 --- a/swift/megatron/trainers/rlhf_base.py +++ b/swift/megatron/trainers/rlhf_base.py @@ -13,7 +13,7 @@ from swift.utils import 
get_current_device, get_logger from .trainer import MegatronTrainer -from .utils import get_batch +from .utils import get_batch, load_megatron_model_to_gpu, offload_megatron_model_to_cpu logger = get_logger() @@ -156,6 +156,6 @@ def null_ref_context(self): self.peft_model.set_adapter('default') @contextmanager - def offload_context(self): + def offload_context(self, model): # TODO: offload yield diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 33b945cb2c..6945cf8daa 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import gc import time from contextlib import contextmanager from typing import Any, Dict, List, Optional @@ -7,12 +8,14 @@ from accelerate.utils import gather as hf_gather from accelerate.utils import gather_object as hf_gather_object from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank from megatron.training import get_args, get_wandb_writer from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device +from swift.utils.torch_utils import empty_cache, get_current_device def get_swift_datasets_provider(train_dataset, val_dataset): @@ -173,7 +176,90 @@ def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = if group is None: return hf_gather_object(object) size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) output_objects = [None for _ in range(size)] - torch.distributed.all_gather_object(output_objects, object) - # flatten - return [x for y in output_objects for x in y] + + try: + # 添加调试信息 + from swift.utils import get_logger + logger = get_logger() + logger.info(f'Rank {rank}/{size} in group starting all_gather_object with {len(object)} objects') + + torch.distributed.all_gather_object(output_objects, object, group=group) + + logger.info(f'Rank {rank}/{size} in group completed all_gather_object successfully') + # flatten + return [x for y in output_objects for x in y] + except Exception as e: + from swift.utils import get_logger + logger = get_logger() + logger.error(f'Rank {rank}/{size} in group failed at all_gather_object: {e}') + logger.error(f"Object size: {len(object) if hasattr(object, '__len__') else 'unknown'}") + if torch.cuda.is_available(): + logger.error(f'GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' + f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') + raise + + +# code borrowed from VeRL +@torch.no_grad() +def load_megatron_model_to_gpu(models, load_grad=True): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # sometimes, we don't want to load grad for pure inference + if load_grad: + buffer.grad_data.storage().resize_(buffer.grad_data_size) + buffer.grad_data.zero_() + + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) + else: + # we need this for ref module + device_id = get_current_device() + for _, param in model_chunk.named_parameters(): + param.data = 
param.data.to(device_id, non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to(device_id, non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory() + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size() + + if buffer.grad_data.storage().size() > 0: + # if the grad_data size is already zero, we assume that it is already offloaded + buffer.grad_data_size = buffer.grad_data.storage().size() + buffer.grad_data.storage().resize_(0) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to('cpu', non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to('cpu', non_blocking=True) + gc.collect() + empty_cache() From 6a2940cfc8f3614d05e78f42828cff10d28dc736 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 1 Oct 2025 18:58:07 +0800 Subject: [PATCH 19/83] fix rollout_group & rollout judgement --- swift/megatron/trainers/grpo_trainer.py | 33 +++++++++---------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 03887abd13..97a8254ed0 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -100,6 +100,8 @@ def _init_grpo_params(self): stop=args.stop_words, return_details=True) + self._step = 0 + def _prepare_rollout_engine(self): args = self.args self.vllm_mode = args.vllm_mode @@ -132,20 +134,6 @@ def _prepare_rollout_engine(self): if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) - self._init_rollout_group() - - def _init_rollout_group(self): - args = self.args - model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size - # each model share the rollout group (gather) - rollout_groups = [list(range(i, i + model_size)) for i in range(0, self.world_size, model_size)] - - for group_ranks in rollout_groups: - if self.process_index in group_ranks: - self.rollout_group = torch.distributed.new_group(ranks=group_ranks) - print(f'rank {self.process_index} join rollout group with ranks: {group_ranks}') - break - def prepare_vllm(self): from swift.llm.infer.infer_engine import GRPOVllmEngine args = self.args @@ -247,7 +235,7 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): def _replace_data_iterator(self, data_iterator): args = get_args() - if args.iteration % self.steps_per_generation == 0: + if args._step % self.steps_per_generation == 0: # each rollout DP group will generate generation_batch_size / world_size completions completions_to_rollout = self.generation_batch_size // mpu.get_data_parallel_world_size() # completions will be repeated 
num_generations times after @@ -274,18 +262,20 @@ def _replace_data_iterator(self, data_iterator): ] assert len(mini_batch_data) == self.steps_per_generation self._buffered_inputs = mini_batch_data - - inputs = self._buffered_inputs[args.iteration % self.steps_per_generation] + self._step += 1 + inputs = self._buffered_inputs[self._step % self.steps_per_generation] return iter(inputs) def _generate_and_score_completions(self, batch): + rollout_group = mpu.get_model_parallel_group() + # batch : same across DP groups def get_local_rollout_batch(batch): # repeat num_generations times global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] # get local rollout data - rollout_rank = torch.distributed.get_rank(group=self.rollout_group) - rollout_group_size = torch.distributed.get_world_size(group=self.rollout_group) + rollout_rank = torch.distributed.get_rank(group=rollout_group) + rollout_group_size = torch.distributed.get_world_size(group=rollout_group) per_device_batch_size = self.per_device_generation_batch_size assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) @@ -347,8 +337,9 @@ def _get_encoded_batch(rollout_batch, advantages): return encoded_batch # Step2: ref/old logps - total_batch = gather_object(rollout_batch, group=self.rollout_group) - total_advantages = gather(advantages, group=self.rollout_group) + rollout_group + total_batch = gather_object(rollout_batch, group=rollout_group) + total_advantages = gather(advantages, group=rollout_group) mini_batch_data = [] for idx in range(0, len(total_batch), self.micro_batch_size): micro_batch_data = _get_encoded_batch(total_batch[idx:idx + self.micro_batch_size], From 486c3d427d1e8e8ec7bc181cf870191b47085cff Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 6 Oct 2025 17:54:59 +0800 Subject: [PATCH 20/83] fix step --- swift/megatron/trainers/grpo_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 97a8254ed0..a512bff30e 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -234,8 +234,7 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): def _replace_data_iterator(self, data_iterator): - args = get_args() - if args._step % self.steps_per_generation == 0: + if self._step % self.steps_per_generation == 0: # each rollout DP group will generate generation_batch_size / world_size completions completions_to_rollout = self.generation_batch_size // mpu.get_data_parallel_world_size() # completions will be repeated num_generations times after From c68d97606c9c2a1a34720760bfad6be5863000ea Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 7 Oct 2025 18:16:36 +0800 Subject: [PATCH 21/83] move old base trainer to newer --- swift/megatron/trainers/base.py | 48 ++++++- swift/megatron/trainers/dpo_trainer.py | 16 --- swift/megatron/trainers/grpo_trainer.py | 2 +- swift/megatron/trainers/rlhf_base.py | 161 ------------------------ 4 files changed, 47 insertions(+), 180 deletions(-) delete mode 100644 swift/megatron/trainers/rlhf_base.py diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 8c8d19aaa5..0a9c403362 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -3,7 +3,7 @@ import os import time from abc import ABC, abstractmethod -from contextlib 
import contextmanager +from contextlib import contextmanager, nullcontext from datetime import datetime from typing import Dict @@ -28,6 +28,7 @@ from megatron.training.training import num_floating_point_operations from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from packaging import version +from torch.distributed.nn import all_reduce from transformers.utils import ContextManagers from swift.llm import Template, dynamic_gradient_checkpointing @@ -35,7 +36,7 @@ from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model -from .utils import get_swift_datasets_provider +from .utils import get_batch, get_swift_datasets_provider logger = get_logger() @@ -71,9 +72,11 @@ def initialize_megatron(*_args, **kwargs): args = get_args() data_parallel_size = mpu.get_data_parallel_world_size() step_batch_size = args.micro_batch_size * data_parallel_size + num_generations = args.num_generations if hasattr(args, 'num_generations') else 1 if args.train_iters is None and args.max_epochs is not None: if hasattr(train_dataset, '__len__'): dataset_sample = len(train_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size else: raise ValueError( @@ -83,6 +86,7 @@ def initialize_megatron(*_args, **kwargs): args.eval_iters = 0 elif hasattr(val_dataset, '__len__'): dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations args.eval_iters = max(dataset_sample // args.global_batch_size, 1) else: raise ValueError( @@ -867,3 +871,43 @@ def _forward_step_helper(model, inputs): output_tensor = None return output_tensor + + def model_forward(self, model, data_iterator, no_grad=True, per_token=False): + # used to calculate model forward (logps) + with self.stimer(bdata=True): + data = get_batch(data_iterator) + data.pop('loss_scale', None) + labels = data.get('labels') + context = torch.no_grad() if no_grad else nullcontext() + with context: + output_tensor = self._forward_step_helper(model, data) + data['logps'] = None if labels is None else self.get_logps( + output_tensor, labels, data['packed_seq_params'], per_token=per_token) + return data + + @staticmethod + def get_logps(output_tensor, labels, packed_seq_params, per_token: bool = False): + args = get_args() + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + num_samples = packed_seq_params.num_samples + if args.rlhf_type == 'dpo': + total_samples = num_samples * 2 + elif args.rlhf_type in 'grpo': + total_samples = num_samples + + cu_seqlens = packed_seq_params.cu_seqlens_q[:total_samples + 1] // args.context_parallel_size + + if per_token: + if args.context_parallel_size > 1: + per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) + return per_token_logps + else: + all_logps = per_token_logps.new_zeros((total_samples, )) + for i in range(total_samples): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) + return all_logps diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 
783dea8a6e..0ded803dc2 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -37,22 +37,6 @@ def __init__(self, args, template): self.dummy_dpo_trainer = DummyDPOTrainer(args) self.ref_models = [] - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples * 2, )) - for i in range(num_samples * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params): ref_output_tensor = output_tensor[:output_tensor.shape[0] // 2].detach() output_tensor = output_tensor[output_tensor.shape[0] // 2:] diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index a512bff30e..4dbf581d38 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -22,7 +22,7 @@ from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronArguments, MegatronRLHFArguments -from .rlhf_base import MegatronRLHFTrainer +from .base import MegatronRLHFTrainer from .utils import gather, gather_object, get_batch, process_packed_seq_params, profiling_context try: diff --git a/swift/megatron/trainers/rlhf_base.py b/swift/megatron/trainers/rlhf_base.py deleted file mode 100644 index 569854274e..0000000000 --- a/swift/megatron/trainers/rlhf_base.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
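# --------------------------------------------------------------------------
# Illustrative sketch (outside the diff): the per-sample logps aggregation being
# consolidated here sums per-token log-probs segment by segment using the packed
# sequence boundaries (cu_seqlens). The numbers below are made up for the example.
import torch

per_token_logps = torch.tensor([[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]])  # [1, total_tokens]
cu_seqlens = torch.tensor([0, 2, 6])  # sample 0 -> tokens [0, 2), sample 1 -> tokens [2, 6)

num_samples = cu_seqlens.numel() - 1
all_logps = per_token_logps.new_zeros((num_samples, ))
for i in range(num_samples):
    start, end = cu_seqlens[i], cu_seqlens[i + 1]
    all_logps[i] = per_token_logps[:, start:end].sum()
print(all_logps)  # tensor([-0.3000, -1.8000])
# --------------------------------------------------------------------------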
-from collections import namedtuple -from contextlib import contextmanager, nullcontext -from functools import partial - -import torch -from megatron.core import mpu -from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank -from megatron.training import get_args, get_model, training -from megatron.training.checkpointing import load_checkpoint -from megatron.training.utils import unwrap_model -from torch.distributed.nn import all_reduce - -from swift.utils import get_current_device, get_logger -from .trainer import MegatronTrainer -from .utils import get_batch, load_megatron_model_to_gpu, offload_megatron_model_to_cpu - -logger = get_logger() - - -class MegatronRLHFTrainer(MegatronTrainer): - - @contextmanager - def _get_iters(self, train_dataset, val_dataset): - origin_initialize_megatron = training.initialize_megatron - - def initialize_megatron(*_args, **kwargs): - res = origin_initialize_megatron(*_args, **kwargs) - args = get_args() - data_parallel_size = mpu.get_data_parallel_world_size() - step_batch_size = args.micro_batch_size * data_parallel_size - if args.train_iters is None and args.max_epochs is not None: - if hasattr(train_dataset, '__len__'): - dataset_sample = len(train_dataset) // step_batch_size * step_batch_size * args.num_generations - args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size - else: - raise ValueError( - 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.') - if args.eval_iters < 0: - if val_dataset is None: - args.eval_iters = 0 - elif hasattr(val_dataset, '__len__'): - dataset_sample = len(val_dataset) // step_batch_size * step_batch_size - args.eval_iters = max(dataset_sample // args.global_batch_size, 1) - else: - raise ValueError( - 'You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') - return res - - training.initialize_megatron = initialize_megatron - try: - yield - finally: - training.initialize_megatron = origin_initialize_megatron - - def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - args = get_args() - if args.train_type == 'full': - ref_model = get_model(model_provider_func, model_type) - if args.ref_load is None: - args.ref_load = args.load - args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( - ref_model, None, None, load_arg='ref_load') - self.ref_model = ref_model[0] - self.ref_model.eval() - else: - self.ref_model = None - return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) - - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor - - def model_forward(self, model, data_iterator, no_grad=True, per_token=False): - # used to calculate model forward (logps) - with self.stimer(bdata=True): - data = get_batch(data_iterator) - data.pop('loss_scale', None) - labels = data.get('labels') - context = torch.no_grad() if no_grad else nullcontext() - with context: - output_tensor = self._forward_step_helper(model, data) - data['logps'] = None if labels is None else self.get_logps( - output_tensor, labels, data['packed_seq_params'], per_token=per_token) - return data - - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params, per_token: bool = False): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - if args.rlhf_type == 'dpo': - total_samples = num_samples * 2 - elif args.rlhf_type in 'grpo': - total_samples = num_samples - - cu_seqlens = packed_seq_params.cu_seqlens_q[:total_samples + 1] // args.context_parallel_size - - if per_token: - if args.context_parallel_size > 1: - per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) - return per_token_logps - else: - all_logps = per_token_logps.new_zeros((total_samples, )) - for i in range(total_samples): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - - @contextmanager - def null_ref_context(self): - args = get_args() - if args.train_type == 'full': - context = nullcontext() - ref_model = 
unwrap_model(self.ref_model) - else: - if args.ref_adapter_load is None: - context = self.peft_model.disable_adapter() - else: - context = nullcontext() - ref_model = self.unwrapped_model - with context: - if args.ref_adapter_load: - self.peft_model.set_adapter('ref_adapter') - yield ref_model - if args.ref_adapter_load: - self.peft_model.set_adapter('default') - - @contextmanager - def offload_context(self, model): - # TODO: offload - yield From 6b1653ccf142bd12491cfa7acda3f4b17896ca31 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 8 Oct 2025 20:47:32 +0800 Subject: [PATCH 22/83] fix --- swift/megatron/trainers/base.py | 13 ---------- swift/megatron/trainers/grpo_trainer.py | 34 +++++++++++++++---------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 0a9c403362..7adee29140 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -872,19 +872,6 @@ def _forward_step_helper(model, inputs): return output_tensor - def model_forward(self, model, data_iterator, no_grad=True, per_token=False): - # used to calculate model forward (logps) - with self.stimer(bdata=True): - data = get_batch(data_iterator) - data.pop('loss_scale', None) - labels = data.get('labels') - context = torch.no_grad() if no_grad else nullcontext() - with context: - output_tensor = self._forward_step_helper(model, data) - data['logps'] = None if labels is None else self.get_logps( - output_tensor, labels, data['packed_seq_params'], per_token=per_token) - return data - @staticmethod def get_logps(output_tensor, labels, packed_seq_params, per_token: bool = False): args = get_args() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 4dbf581d38..1808c3a64e 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -169,7 +169,7 @@ def _move_model_to_vllm(self): if self.bridge is None: self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) self._patch_mbridge(self.bridge) - per_tensor_params = self.bridge.export_weights([self.unwrapped_model]) + per_tensor_params = self.bridge.export_weights(self.unwrapped_models) self.engine.inner_model.load_weights(per_tensor_params) # TODO: check tensor_model_parallel def _prepare_rewards(self): @@ -560,7 +560,7 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor """Log reward statistics for monitoring. Only log once per unique request_id.""" # rewards: [prompt_batch_size, self.num_generations] # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] - mode = 'train' if self.unwrapped_model.training else 'eval' + mode = 'train' if self.unwrapped_models[0].training else 'eval' group_rewards = rewards.view(-1, self.num_generations) rewards_mean = group_rewards.mean(-1).mean().item() rewards_std = group_rewards.std(-1).mean().item() @@ -587,19 +587,16 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: # TODO: entropy inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']} - if self.beta != 0.0: - with torch.no_grad(), self.null_ref_context() as ref_model: + with torch.no_grad(), self.null_ref_context() as ref_models: + assert len(ref_models) == 1, 'KTO currently does not support VPP.' 
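# --------------------------------------------------------------------------
# Illustrative sketch (outside the diff): the reference log-probs computed in this
# block feed the per-token KL penalty used later in the GRPO loss,
#     kl = exp(ref - cur) - (ref - cur) - 1,
# which is non-negative and zero only where the two policies agree. The toy
# tensors below are assumptions for the example.
import torch

cur_logps = torch.tensor([-1.20, -0.80, -2.10])  # log-probs under the current policy
ref_logps = torch.tensor([-1.00, -0.95, -2.10])  # log-probs under the frozen reference

diff = ref_logps - cur_logps
per_token_kl = torch.exp(diff) - diff - 1
print(per_token_kl)  # non-negative, ~0 where the distributions match
# --------------------------------------------------------------------------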
+ ref_model = ref_models[0] batch['ref_per_token_logps'] = self.model_forward( ref_model, iter([inputs]), no_grad=True, per_token=True)['logps'] if not self.on_policy: batch['old_per_token_logps'] = self.model_forward( - self.unwrapped_model, iter([inputs]), no_grad=True, per_token=True)['logps'] - - # get packed_seq_params, from get_batch func - # batch = process_packed_seq_params(batch) - + self.unwrapped_models[0], iter([inputs]), no_grad=True, per_token=True)['logps'] return batch @contextmanager @@ -685,8 +682,8 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) padding_length = completion_mask.shape[1] - truncated_mask.shape[1] if padding_length > 0: - padding = torch.zeros(padding_length, device=truncated_mask.device, dtype=truncated_mask.dtype) - truncated_mask = torch.cat([truncated_mask, padding]) + padding = torch.zeros((1, padding_length), device=truncated_mask.device, dtype=truncated_mask.dtype) + truncated_mask = torch.cat([truncated_mask, padding], dim=1) completion_mask = completion_mask & (~truncated_mask) if self.beta != 0.0: @@ -737,8 +734,6 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) for loss, mask in zip(loss_list, mask_list)] loss = torch.stack(sample_loss[:micro_batch_size]).mean() - - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() elif self.loss_type == 'bnpo': loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': @@ -782,3 +777,16 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 loss = loss / mpu.get_context_parallel_world_size() return loss, reporting_metric + + def model_forward(self, model, data_iterator, no_grad=True, per_token=False): + # used to calculate model forward (logps) in GRPO + with self.stimer(bdata=True): + data = get_batch(data_iterator) + data.pop('loss_scale', None) + labels = data.get('labels') + context = torch.no_grad() if no_grad else nullcontext() + with context: + output_tensor = self._forward_step_helper(model, data) + data['logps'] = None if labels is None else self.get_logps( + output_tensor, labels, data['packed_seq_params'], per_token=per_token) + return data From d4a9dcc2d74c8266cfd4d64c896c2741eb2ead6e Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 8 Oct 2025 21:02:42 +0800 Subject: [PATCH 23/83] offload utils --- swift/megatron/trainers/grpo_trainer.py | 7 +++ swift/megatron/trainers/utils.py | 68 ++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 1808c3a64e..1754554b8d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -380,9 +380,16 @@ def _generate_completions(self, batch): and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): self.engine.engine.wake_up(tags=['kv_cache']) + # Step3: Rollout batch = self.preprocess_rollout_data(batch) outputs: List[RolloutOutput] = self._rollout(batch) + + # Step4: Sleep to release memory + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + batch = 
self.postprocess_rollout_data(batch, outputs) + return batch def preprocess_rollout_data(self, batch): diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index bf56904abe..1c07a1971c 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -9,6 +9,7 @@ from accelerate.utils import gather_object as hf_gather_object from megatron.core import mpu from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.optimizer import ChainedOptimizer from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank from megatron.training import get_args, get_wandb_writer @@ -270,7 +271,7 @@ def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = raise -# code borrowed from VeRL +# code borrowed from verl @torch.no_grad() def load_megatron_model_to_gpu(models, load_grad=True): for model_chunk in models: @@ -332,3 +333,68 @@ def offload_megatron_model_to_cpu(models): param.grad = param.grad.to('cpu', non_blocking=True) gc.collect() empty_cache() + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = get_current_device() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, 'shard_fp32_from_float16_groups'): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, '_move_new_state_to_right_device'): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if 'exp_avg' in v: + v['exp_avg'] = v['exp_avg'].to(get_current_device(), non_blocking=True) + if 'exp_avg_sq' in v: + v['exp_avg_sq'] = v['exp_avg_sq'].to(get_current_device(), non_blocking=True) + gc.collect() + empty_cache() From 9dc92a02a348d78c7c39a4ebd8ad79824ea2c522 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 9 Oct 2025 09:57:13 +0800 Subject: [PATCH 24/83] offload context --- swift/megatron/trainers/grpo_trainer.py | 71 +++++++++++++------ swift/megatron/trainers/utils.py | 93 +++++++++++++++++++------ 2 files changed, 119 insertions(+), 45 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 1754554b8d..37ae985720 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -23,7 +23,9 @@ from swift.utils import 
get_current_device, get_logger, is_vllm_available, remove_response from ..argument import MegatronArguments, MegatronRLHFArguments from .base import MegatronRLHFTrainer -from .utils import gather, gather_object, get_batch, process_packed_seq_params, profiling_context +from .utils import (gather, gather_object, get_batch, load_megatron_model_to_gpu, load_megatron_optimizer, + log_gpu_memory, offload_megatron_model_to_cpu, offload_megatron_optimizer, + process_packed_seq_params, profiling_context) try: from mbridge import AutoBridge @@ -133,6 +135,7 @@ def _prepare_rollout_engine(self): self.engine = self.prepare_vllm() if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) + log_gpu_memory('after sleep vLLM engine') def prepare_vllm(self): from swift.llm.infer.infer_engine import GRPOVllmEngine @@ -367,28 +370,32 @@ def _generate_completions(self, batch): # TODO: server mode assert self.vllm_mode == 'colocate' # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) - if self.engine.inner_model_executor.is_sleeping: - wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters - # Load weights only (faster and reduces memory peak) - kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} - self.engine.engine.wake_up(**kwargs) - - # Step 2: Load model weights - self._move_model_to_vllm() - - if (self.engine.inner_model_executor.is_sleeping - and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): - self.engine.engine.wake_up(tags=['kv_cache']) - - # Step3: Rollout - batch = self.preprocess_rollout_data(batch) - outputs: List[RolloutOutput] = self._rollout(batch) - - # Step4: Sleep to release memory - if self.args.sleep_level > 0: - self.engine.engine.sleep(self.args.sleep_level) - - batch = self.postprocess_rollout_data(batch, outputs) + context = self.offload_context if self.enable_offload else nullcontext + with context(): + if self.engine.inner_model_executor.is_sleeping: + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) + log_gpu_memory(f'after wake up vLLM engine with {kwargs}') + + # Step 2: Load model weights + self._move_model_to_vllm() + + if (self.engine.inner_model_executor.is_sleeping + and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): + self.engine.engine.wake_up(tags=['kv_cache']) + log_gpu_memory('after wake up vLLM engine with kv_cache') + + # Step3: Rollout + batch = self.preprocess_rollout_data(batch) + outputs: List[RolloutOutput] = self._rollout(batch) + + # Step4: Sleep to release memory + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + log_gpu_memory('after sleep vLLM engine') + batch = self.postprocess_rollout_data(batch, outputs) return batch @@ -797,3 +804,21 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): data['logps'] = None if labels is None else self.get_logps( output_tensor, labels, data['packed_seq_params'], per_token=per_token) return data + + @contextmanager + def offload_context(self): + if self.args.offload_model: + offload_megatron_model_to_cpu(self.unwrapped_models) + log_gpu_memory('after offload model to cpu') + # if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + # self.offload_optimizer() + + try: + yield + finally: + # reload (load back) model when exiting context + if 
self.args.offload_model: + load_megatron_model_to_gpu(self.unwrapped_models) + log_gpu_memory('after load model to gpu') + # if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + # self.load_optimizer() diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 1c07a1971c..2adca0d252 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -16,6 +16,7 @@ from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device +from swift.utils import get_logger from swift.utils.torch_utils import empty_cache, get_current_device @@ -246,29 +247,9 @@ def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = if group is None: return hf_gather_object(object) size = torch.distributed.get_world_size(group=group) - rank = torch.distributed.get_rank(group=group) output_objects = [None for _ in range(size)] - - try: - # 添加调试信息 - from swift.utils import get_logger - logger = get_logger() - logger.info(f'Rank {rank}/{size} in group starting all_gather_object with {len(object)} objects') - - torch.distributed.all_gather_object(output_objects, object, group=group) - - logger.info(f'Rank {rank}/{size} in group completed all_gather_object successfully') - # flatten - return [x for y in output_objects for x in y] - except Exception as e: - from swift.utils import get_logger - logger = get_logger() - logger.error(f'Rank {rank}/{size} in group failed at all_gather_object: {e}') - logger.error(f"Object size: {len(object) if hasattr(object, '__len__') else 'unknown'}") - if torch.cuda.is_available(): - logger.error(f'GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' - f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') - raise + torch.distributed.all_gather_object(output_objects, object, group=group) + return [x for y in output_objects for x in y] # code borrowed from verl @@ -376,6 +357,47 @@ def load_group_to_gpu(group): load_group_to_gpu(_opt.shard_fp32_from_float16_groups) +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. 
+ """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to('cpu', non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, 'shard_fp32_from_float16_groups'): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + @torch.no_grad() def load_megatron_optimizer(optimizers): @@ -398,3 +420,30 @@ def _iter_opts(opt): v['exp_avg_sq'] = v['exp_avg_sq'].to(get_current_device(), non_blocking=True) gc.collect() empty_cache() + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if 'exp_avg' in v: + v['exp_avg'] = v['exp_avg'].to('cpu', non_blocking=True) + if 'exp_avg_sq' in v: + v['exp_avg_sq'] = v['exp_avg_sq'].to('cpu', non_blocking=True) + gc.collect() + empty_cache() + + +def log_gpu_memory(prefix: str = ''): + logger = get_logger() + + logger.info(f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' + f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') From 91f97ca9757479875b206e5a2ae2795635aa53f8 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 9 Oct 2025 10:43:02 +0800 Subject: [PATCH 25/83] fix resolve --- swift/megatron/argument/megatron_args.py | 2 -- swift/megatron/argument/train_args.py | 1 - swift/megatron/trainers/grpo_trainer.py | 19 ++++++++++--------- swift/megatron/trainers/rlhf_mixin.py | 7 +++++-- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index e246c5a81c..3672a49ef7 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -149,8 +149,6 @@ def __post_init__(self): self._init_kto() if self.rlhf_type == 'grpo': self._init_grpo() - super().__post_init__() - if self.rlhf_type == 'grpo': self._set_grpo_default() def _set_grpo_default(self): diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index ff032260c2..7552e65a00 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -29,7 +29,6 @@ def init_model_args(self, tokenizer, config): if getattr(self, k) is None: setattr(self, k, v) MegatronArguments.__post_init__(self) - RLHFMegatronArgumentsMixin.__post_init__(self) self.extra_args = self.parse_to_megatron() self.extra_args['model_info'] = self.model_info self.extra_args['model_meta'] = self.model_meta diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 37ae985720..b01063c084 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -22,10 +22,9 @@ from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids from swift.utils import get_current_device, get_logger, 
is_vllm_available, remove_response from ..argument import MegatronArguments, MegatronRLHFArguments -from .base import MegatronRLHFTrainer -from .utils import (gather, gather_object, get_batch, load_megatron_model_to_gpu, load_megatron_optimizer, - log_gpu_memory, offload_megatron_model_to_cpu, offload_megatron_optimizer, - process_packed_seq_params, profiling_context) +from .rlhf_mixin import MegatronRLHFTrainer +from .utils import (gather, gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, + offload_megatron_model_to_cpu, offload_megatron_optimizer, profiling_context) try: from mbridge import AutoBridge @@ -235,7 +234,7 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): bridge._weight_to_hf_format = _weight_to_hf_format_patched - def _replace_data_iterator(self, data_iterator): + def _replace_data_iterator(self, data_iterator, model): if self._step % self.steps_per_generation == 0: # each rollout DP group will generate generation_batch_size / world_size completions @@ -667,7 +666,7 @@ def build_pretraining_data_loader(*_args, **kwargs): def forward_step(self, data_iterator, model): # train_batch_size # return: output_tensor, loss_func - data = get_batch(data_iterator) + data = self.get_batch(data_iterator) data.pop('loss_scale', None) inputs = { k: v @@ -689,7 +688,8 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size] lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] - per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, per_token=True) + per_token_logps = self.get_logps( + output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True) if self.args.overlong_filter and any(truncated_mask): # TODO: non-padding-free @@ -795,14 +795,15 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): def model_forward(self, model, data_iterator, no_grad=True, per_token=False): # used to calculate model forward (logps) in GRPO with self.stimer(bdata=True): - data = get_batch(data_iterator) + data = self.get_batch(data_iterator) data.pop('loss_scale', None) labels = data.get('labels') context = torch.no_grad() if no_grad else nullcontext() with context: output_tensor = self._forward_step_helper(model, data) + packed_seq_params = data['packed_seq_params'] data['logps'] = None if labels is None else self.get_logps( - output_tensor, labels, data['packed_seq_params'], per_token=per_token) + output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token) return data @contextmanager diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index ead111435e..69be62cd00 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -82,7 +82,7 @@ def _forward_step_helper(model, inputs): return output_tensor - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): + def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, per_token=False): args = get_args() per_token_logps = -output_tensor loss_mask = labels != -100 @@ -93,7 +93,10 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): all_logps = per_token_logps.new_zeros((num_samples, )) for i in range(num_samples): start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = 
per_token_logps[:, start:end].sum() + if per_token: + all_logps[i] = per_token_logps[:, start:end] + else: + all_logps[i] = per_token_logps[:, start:end].sum() if args.context_parallel_size > 1: all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) return all_logps From 59f436c1923ac828fe7fddf6855b2334b2559cd8 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 9 Oct 2025 10:52:13 +0800 Subject: [PATCH 26/83] fix logps --- swift/megatron/trainers/grpo_trainer.py | 2 +- swift/megatron/trainers/rlhf_mixin.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index b01063c084..15f8da5c37 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -602,7 +602,7 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']} if self.beta != 0.0: with torch.no_grad(), self.null_ref_context() as ref_models: - assert len(ref_models) == 1, 'KTO currently does not support VPP.' + assert len(ref_models) == 1, 'GRPO currently does not support VPP.' ref_model = ref_models[0] batch['ref_per_token_logps'] = self.model_forward( ref_model, iter([inputs]), no_grad=True, per_token=True)['logps'] diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 69be62cd00..7c937ab308 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -87,16 +87,18 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask + if per_token: + if args.context_parallel_size > 1: + per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) + return per_token_logps + if num_samples is None: num_samples = packed_seq_params.num_samples * 2 cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size all_logps = per_token_logps.new_zeros((num_samples, )) for i in range(num_samples): start, end = cu_seqlens[i], cu_seqlens[i + 1] - if per_token: - all_logps[i] = per_token_logps[:, start:end] - else: - all_logps[i] = per_token_logps[:, start:end].sum() + all_logps[i] = per_token_logps[:, start:end].sum() if args.context_parallel_size > 1: all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) return all_logps From 8dea6d73908203e4cb4f5eb2a50510abf9594dec Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 9 Oct 2025 11:57:08 +0800 Subject: [PATCH 27/83] fix old logps --- swift/megatron/trainers/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 15f8da5c37..6fa2238d6d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -605,11 +605,11 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: assert len(ref_models) == 1, 'GRPO currently does not support VPP.' 
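# ---------------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of this patch.]
# The ref/old per-token log-probs computed just below (ref_per_token_logps /
# old_per_token_logps) are later consumed by the GRPO objective in loss_func.
# A minimal, self-contained version of that per-token objective is sketched here;
# the standalone function and tensor names are assumptions for illustration only,
# not the trainer's actual implementation.
import torch


def grpo_per_token_loss(per_token_logps: torch.Tensor,
                        old_per_token_logps: torch.Tensor,
                        ref_per_token_logps: torch.Tensor,
                        advantages: torch.Tensor,
                        beta: float = 0.04,
                        epsilon: float = 0.2) -> torch.Tensor:
    # importance ratio w.r.t. the behaviour (old) policy, per token
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    # PPO-style clipped surrogate; `advantages` is per-token (or broadcastable)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    loss = -torch.min(unclipped, clipped)
    if beta != 0.0:
        # k3 estimator of KL(pi_theta || pi_ref), added per token
        log_r = ref_per_token_logps - per_token_logps
        loss = loss + beta * (torch.exp(log_r) - log_r - 1)
    return loss
# When training stays on-policy, old_per_token_logps is just the detached current
# log-probs, so the ratio is 1 and only the KL term shapes the per-token loss.
# ---------------------------------------------------------------------------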
ref_model = ref_models[0] batch['ref_per_token_logps'] = self.model_forward( - ref_model, iter([inputs]), no_grad=True, per_token=True)['logps'] + ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] if not self.on_policy: batch['old_per_token_logps'] = self.model_forward( - self.unwrapped_models[0], iter([inputs]), no_grad=True, per_token=True)['logps'] + self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] return batch @contextmanager From abac6967c5dac4d14c3cca3c17a83248608a06bc Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 9 Oct 2025 15:32:26 +0800 Subject: [PATCH 28/83] reduce redundancy --- swift/megatron/argument/megatron_args.py | 33 +++--------------------- swift/megatron/trainers/base.py | 27 ------------------- swift/megatron/trainers/utils.py | 12 --------- 3 files changed, 4 insertions(+), 68 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 3672a49ef7..5040326307 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -39,6 +39,8 @@ class RLHFMegatronArgumentsMixin: steps_per_generation: Optional[int] = None num_generations: int = 8 max_completion_length: int = 512 + # GSPO https://www.arxiv.org/abs/2507.18071 + importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' # ─────────────────────────── Sampling ─────────────────────────── epsilon: float = 0.2 @@ -60,7 +62,7 @@ class RLHFMegatronArgumentsMixin: vllm_disable_cascade_attn: bool = False sleep_level: Literal[0, 1, 2] = 0 - # ────────────── External VLLM (server) ────────────── + # ────────────── External VLLM (server, not supported yet) ────────────── vllm_server_base_url: Optional[List[str]] = None vllm_server_host: Optional[List[str]] = None vllm_server_port: List[int] = field(default_factory=lambda: [8000]) @@ -84,9 +86,6 @@ class RLHFMegatronArgumentsMixin: soft_max_length: Optional[int] = None soft_cache_length: Optional[int] = None - reward_model: Optional[List[str]] = None - reward_model_plugin: Optional[List[str]] = None - # ─────────────────────────── Not Supported Yet ─────────────────────────── # reward model reward_model: Optional[List[str]] = None @@ -123,9 +122,6 @@ class RLHFMegatronArgumentsMixin: # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 top_entropy_quantile: float = 1.0 - # GSPO https://www.arxiv.org/abs/2507.18071 - importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' - wandb_log_unique_prompts: Optional[bool] = None num_iterations: int = 1 @@ -149,11 +145,6 @@ def __post_init__(self): self._init_kto() if self.rlhf_type == 'grpo': self._init_grpo() - self._set_grpo_default() - - def _set_grpo_default(self): - if self.beta is None: - self.beta = 0.04 # https://arxiv.org/abs/2402.03300 def _init_grpo(self): @@ -172,23 +163,9 @@ def _init_external_vllm(): logger.info('Connected to vLLM server') def _check_not_supported(): - # TODO: check - # bool - not_supported_args = [ - 'sync_ref_model', - 'async_generate', - ] - for arg in not_supported_args: - if getattr(self, arg): - raise ValueError(f'{arg} is not supported for Megatron-GRPO yet, please unset it.') - # else - if self.num_iterations > 1: - raise ValueError('num_iterations > 1 is not supported for Megatron-GRPO yet, please set it to 1.') + pass def _check_batch_params(): - # assert self.micro_batch_size % self.num_generations == 0, \ - # f'micro_batch_size ({self.micro_batch_size}) must be divisible' 
\ - # f' by the number of generations ({self.num_generations})' if self.generation_batch_size is None and self.steps_per_generation is None: self.steps_per_generation = 1 self.generation_batch_size = self.global_batch_size * self.steps_per_generation @@ -213,8 +190,6 @@ def _check_batch_params(): _check_not_supported() _check_batch_params() # default loss_type if no loss_type is provided - if self.loss_type == 'sigmoid': - self.loss_type = 'grpo' assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \ f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}' if self.async_generate or not self.use_vllm: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index da6d47e454..c0bb9bf5ec 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -827,30 +827,3 @@ def get_batch(self, data_iterator, vp_stage=None): if is_finished: args.train_iters = args.curr_iteration + 1 return self._prepare_batch(data, vp_stage) - - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params, per_token: bool = False): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - if args.rlhf_type == 'dpo': - total_samples = num_samples * 2 - elif args.rlhf_type in 'grpo': - total_samples = num_samples - - cu_seqlens = packed_seq_params.cu_seqlens_q[:total_samples + 1] // args.context_parallel_size - - if per_token: - if args.context_parallel_size > 1: - per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) - return per_token_logps - else: - all_logps = per_token_logps.new_zeros((total_samples, )) - for i in range(total_samples): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 2dfeb7014c..0adb9d6930 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -65,18 +65,6 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams: qkv_format='thd') -def process_packed_seq_params(batch: Dict[str, Any]) -> int: - args = get_args() - num_samples = batch.pop('num_samples') - text_position_ids = batch.pop('text_position_ids', None) - if text_position_ids is None: - text_position_ids = batch.get('position_ids') - if args.padding_free and text_position_ids is not None: - batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) - batch['packed_seq_params'].num_samples = num_samples - return batch - - def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int): if dim < 0: dim = (dim + inputs.ndim) % inputs.ndim From 3a3ff37c6b512867e1460bfb9c256987d4340884 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 10 Oct 2025 10:34:51 +0800 Subject: [PATCH 29/83] replace token --- swift/megatron/train/rlhf.py | 1 - swift/megatron/trainers/grpo_trainer.py | 88 +++++++++++++------------ 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index e74a80030f..5b133bdcde 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -34,7 +34,6 @@ def _prepare_template(self) -> None: def _get_data_collator(self): if self.args.rlhf_type == 'grpo': - super()._get_data_collator() 
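# ---------------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of this patch.]
# For GRPO the dataloader hands the trainer raw samples; padding and tensorization
# only happen after rollout, so the collator returned below is effectively the
# identity function. The helper here is a hypothetical stand-in showing that
# assumed contract -- it is not the real `identity_data_collator`.
from typing import Any, Dict, List


def identity_collate(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # no padding, no tensor conversion -- encoding happens later in the GRPO trainer
    return batch
# ---------------------------------------------------------------------------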
return identity_data_collator return super()._get_data_collator() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 6fa2238d6d..36b8a92ab8 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -172,7 +172,7 @@ def _move_model_to_vllm(self): self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) self._patch_mbridge(self.bridge) per_tensor_params = self.bridge.export_weights(self.unwrapped_models) - self.engine.inner_model.load_weights(per_tensor_params) # TODO: check tensor_model_parallel + self.engine.inner_model.load_weights(per_tensor_params) def _prepare_rewards(self): # TODO: reward model @@ -300,31 +300,31 @@ def _get_encoded_batch(rollout_batch, advantages): labels = encoded_batch['labels'] # TODO: logits_to_keep # logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() - if self.template.padding_free: - position_ids = encoded_batch.get('text_position_ids') - if position_ids is None: - position_ids = encoded_batch.get('position_ids') - squeezed_position_ids = position_ids.squeeze() - assert squeezed_position_ids is not None - # Remove trailing padding zeros from position_ids to avoid interference - # Find the last non-zero position - last_nonzero_idx = (squeezed_position_ids != 0).nonzero(as_tuple=True)[0] - if len(last_nonzero_idx) > 0: - # Keep only up to the last non-zero position + 1 to include the last valid position - squeezed_position_ids = squeezed_position_ids[:last_nonzero_idx[-1] + 1] - - # Calculate lengths based on sequence boundaries (position_ids == 0) - lengths = torch.diff( - torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], - torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) - advantages = torch.repeat_interleave(advantages, lengths) - - # Pad advantages to match the original position_ids length - original_length = position_ids.shape[1] - if advantages.shape[0] < original_length: - padding_length = original_length - advantages.shape[0] - padding = torch.zeros(padding_length, device=advantages.device, dtype=advantages.dtype) - advantages = torch.cat([advantages, padding]) + assert self.template.padding_free + position_ids = encoded_batch.get('text_position_ids') + if position_ids is None: + position_ids = encoded_batch.get('position_ids') + squeezed_position_ids = position_ids.squeeze() + assert squeezed_position_ids is not None + # Remove trailing padding zeros from position_ids to avoid interference + # Find the last non-zero position + last_nonzero_idx = (squeezed_position_ids != 0).nonzero(as_tuple=True)[0] + if len(last_nonzero_idx) > 0: + # Keep only up to the last non-zero position + 1 to include the last valid position + squeezed_position_ids = squeezed_position_ids[:last_nonzero_idx[-1] + 1] + + # Calculate lengths based on sequence boundaries (position_ids == 0) + lengths = torch.diff( + torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], + torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) + advantages = torch.repeat_interleave(advantages, lengths) + + # Pad advantages to match the original position_ids length + original_length = position_ids.shape[1] + if advantages.shape[0] < original_length: + padding_length = original_length - advantages.shape[0] + padding = torch.zeros(padding_length, device=advantages.device, dtype=advantages.dtype) + advantages = torch.cat([advantages, padding]) encoded_batch.update({ 'completion_mask': @@ 
-338,13 +338,14 @@ def _get_encoded_batch(rollout_batch, advantages): return encoded_batch # Step2: ref/old logps - rollout_group total_batch = gather_object(rollout_batch, group=rollout_group) total_advantages = gather(advantages, group=rollout_group) mini_batch_data = [] for idx in range(0, len(total_batch), self.micro_batch_size): - micro_batch_data = _get_encoded_batch(total_batch[idx:idx + self.micro_batch_size], - total_advantages[idx:idx + self.micro_batch_size]) + micro_batch_data = total_batch[idx:idx + self.micro_batch_size] + micro_batch_data = self._maybe_replace_response_token(micro_batch_data) + micro_batch_advantages = total_advantages[idx:idx + self.micro_batch_size] + micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages) micro_batch_data = self._maybe_compute_logps(micro_batch_data) mini_batch_data.append(micro_batch_data) @@ -369,18 +370,18 @@ def _generate_completions(self, batch): # TODO: server mode assert self.vllm_mode == 'colocate' # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) - context = self.offload_context if self.enable_offload else nullcontext - with context(): - if self.engine.inner_model_executor.is_sleeping: - wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters - # Load weights only (faster and reduces memory peak) - kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} - self.engine.engine.wake_up(**kwargs) - log_gpu_memory(f'after wake up vLLM engine with {kwargs}') + if self.engine.inner_model_executor.is_sleeping: + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) + log_gpu_memory(f'after wake up vLLM engine with {kwargs}') - # Step 2: Load model weights - self._move_model_to_vllm() + # Step 2: Load model weights + self._move_model_to_vllm() + context = self.offload_context if self.enable_offload else nullcontext + with context(): if (self.engine.inner_model_executor.is_sleeping and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): self.engine.engine.wake_up(tags=['kv_cache']) @@ -555,6 +556,7 @@ def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tens return advantages / (rewards_std + 1e-4) return advantages + assert len(batch) == rewards_per_func.shape[0] total_rewards_per_func = gather(rewards_per_func) rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) grouped_rewards = rewards.view(-1, self.num_generations) @@ -624,6 +626,10 @@ def _disable_maxlength_template_context(self, template: Template): def _maybe_replace_response_token(self, batch): # maybe replace the response token with the response token ids to avoid repetitive tokenize + + # ignore when loss_scale is set + if self.template.loss_scale.name != 'last_round': + return batch for data in batch: if 'response_token_ids' in data and data['response_token_ids']: loss_mask = None @@ -711,7 +717,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): if self.importance_sampling_level == 'token': log_importance_weights = log_ratio - elif self.importance_sampling_level == 'sequence': + elif self.importance_sampling_level in ['sequence', 'sequence_token']: log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist()) mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) seq_weights = [(lr * 
m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)] @@ -755,7 +761,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): else: raise ValueError(f'Unknown loss type: {self.loss_type}') - loss = loss.mean() + # loss = loss.mean() avg_metric = { 'loss': loss.clone().detach(), 'completions/mean_length': lengths.float().mean(), From 2cd89dc813b8f16c76354ce7da42a549ef7fdb7e Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 10 Oct 2025 20:27:21 +0800 Subject: [PATCH 30/83] fix offload model --- swift/megatron/trainers/base.py | 2 ++ swift/megatron/trainers/grpo_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index c0bb9bf5ec..05f8152b08 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -48,6 +48,7 @@ def __init__(self, args, template: Template): self.template = template self.stimer = StragglerDetector() self.unwrapped_models = [] + self.wrapped_models = [] self.peft_models = [] logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') @@ -266,6 +267,7 @@ def new_model_provider_func(*args, **kwargs): with self._patch_load_state_dict(self._load_base_checkpoint): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( new_model_provider_func, model_type, *_args, **kwargs) + self.wrapped_models = model if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 36b8a92ab8..82a3888040 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -815,7 +815,7 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): @contextmanager def offload_context(self): if self.args.offload_model: - offload_megatron_model_to_cpu(self.unwrapped_models) + offload_megatron_model_to_cpu(self.wrapped_models) log_gpu_memory('after offload model to cpu') # if getattr(self, 'optimizer', None) and self.args.offload_optimizer: # self.offload_optimizer() @@ -825,7 +825,7 @@ def offload_context(self): finally: # reload (load back) model when exiting context if self.args.offload_model: - load_megatron_model_to_gpu(self.unwrapped_models) + load_megatron_model_to_gpu(self.wrapped_models) log_gpu_memory('after load model to gpu') # if getattr(self, 'optimizer', None) and self.args.offload_optimizer: # self.load_optimizer() From 50d5e6f65afc2457c061b80c00a8657b9d35536c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 11 Oct 2025 10:22:50 +0800 Subject: [PATCH 31/83] offload optimizer & ref --- swift/megatron/trainers/grpo_trainer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 82a3888040..3dfd6d0045 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -816,9 +816,12 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): def offload_context(self): if self.args.offload_model: offload_megatron_model_to_cpu(self.wrapped_models) + if hasattr(self, 'ref_models') and self.ref_models: + offload_megatron_model_to_cpu(self.ref_models) log_gpu_memory('after offload model to cpu') - # if getattr(self, 'optimizer', None) and self.args.offload_optimizer: - # self.offload_optimizer() + if getattr(self, 
'optimizer', None) and self.args.offload_optimizer: + offload_megatron_optimizer(self.optimizer) + log_gpu_memory('after offload optimizer to cpu') try: yield @@ -826,6 +829,9 @@ def offload_context(self): # reload (load back) model when exiting context if self.args.offload_model: load_megatron_model_to_gpu(self.wrapped_models) + if hasattr(self, 'ref_models') and self.ref_models: + load_megatron_model_to_gpu(self.ref_models) log_gpu_memory('after load model to gpu') - # if getattr(self, 'optimizer', None) and self.args.offload_optimizer: - # self.load_optimizer() + if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + load_megatron_optimizer(self.optimizer) + log_gpu_memory('after load optimizer to gpu') From e1a06c658c5c7e967a2b29c7c2aed5570620fa16 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 11 Oct 2025 15:42:25 +0800 Subject: [PATCH 32/83] support cp --- swift/llm/template/base.py | 2 + swift/megatron/trainers/grpo_trainer.py | 106 ++++++++++++++++++++---- swift/megatron/trainers/utils.py | 7 +- 3 files changed, 98 insertions(+), 17 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 78291888d2..f09c208afe 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1266,6 +1266,8 @@ def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None: cp_size = self.sequence_parallel_size if not self.use_megatron or cp_size == 1: return + if self.mode == 'vllm': # skip for megatron grpo rollout + return input_ids = encoded['input_ids'] padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids) input_ids += [self.tokenizer.pad_token_id] * padding_len diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 3dfd6d0045..4b05ae661d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -24,7 +24,7 @@ from ..argument import MegatronArguments, MegatronRLHFArguments from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, - offload_megatron_model_to_cpu, offload_megatron_optimizer, profiling_context) + offload_megatron_model_to_cpu, offload_megatron_optimizer, profiling_context, split_cp_inputs) try: from mbridge import AutoBridge @@ -102,6 +102,7 @@ def _init_grpo_params(self): return_details=True) self._step = 0 + self._rollout_group = None # Will be lazily initialized def _prepare_rollout_engine(self): args = self.args @@ -234,6 +235,62 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): bridge._weight_to_hf_format = _weight_to_hf_format_patched + def _get_rollout_group(self): + """ + Get or create the rollout process group (TP×PP×CP). + + The rollout group is used for: + 1. Data slicing: distributing rollout data across all model parallel ranks (including CP) + 2. Gather operations: collecting results from all model parallel ranks (including CP) + + Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct + data distribution during rollout phase. + + Key insight: ranks with the same DP index but different TP/PP/CP indices should be + in the same rollout group. 
These ranks will: + - During rollout: each process different data slices + - During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split) + - During gather: collect all data from TP×PP×CP ranks for training + """ + if self._rollout_group is not None: + return self._rollout_group + + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + # No CP, use the standard MODEL_PARALLEL_GROUP + self._rollout_group = mpu.get_model_parallel_group() + return self._rollout_group + + # Get parallel dimensions + tp_size = mpu.get_tensor_model_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + global_rank = torch.distributed.get_rank() + + # Calculate rollout group size + rollout_group_size = tp_size * pp_size * cp_size + + # Simple and reliable method: assume ranks are organized in contiguous blocks per DP group + # This is typically true for the default order (tp-cp-ep-dp-pp) + # Each DP group has rollout_group_size consecutive ranks + ranks_per_dp_group = rollout_group_size + my_dp_block_index = global_rank // ranks_per_dp_group + + # Calculate the rank range for my rollout group + group_start = my_dp_block_index * ranks_per_dp_group + + # Create all rollout groups (must be done on all ranks) + if not hasattr(self, '_rollout_groups_created'): + for dp_idx in range(dp_size): + group_start = dp_idx * ranks_per_dp_group + group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size))) + group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP') + if global_rank in group_ranks: + self._rollout_group = group + self._rollout_groups_created = True + + return self._rollout_group + def _replace_data_iterator(self, data_iterator, model): if self._step % self.steps_per_generation == 0: @@ -268,17 +325,28 @@ def _replace_data_iterator(self, data_iterator, model): return iter(inputs) def _generate_and_score_completions(self, batch): - rollout_group = mpu.get_model_parallel_group() + # Get or create the rollout group (TP×PP×CP) + # We need to include CP in the rollout group because: + # 1. Data slicing: each rank (including CP ranks) should get different data + # 2. Gather operations: we need to gather data from all TP×PP×CP ranks + rollout_group = self._get_rollout_group() # batch : same across DP groups def get_local_rollout_batch(batch): # repeat num_generations times global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] # get local rollout data + # For rollout, we need to distribute data across TP×PP×CP ranks + # Note: During rollout (vLLM inference), each GPU processes different data + # CP will only take effect during training (forward/backward) rollout_rank = torch.distributed.get_rank(group=rollout_group) rollout_group_size = torch.distributed.get_world_size(group=rollout_group) + per_device_batch_size = self.per_device_generation_batch_size - assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) + assert rollout_group_size * per_device_batch_size == len(global_rollout_batch), ( + f'rollout_group_size ({rollout_group_size}) * per_device_batch_size ({per_device_batch_size}) ' + f'!= len(global_rollout_batch) ({len(global_rollout_batch)}). 
' + f'rollout_rank={rollout_rank}') data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) rollout_batch = global_rollout_batch[data_slice] return rollout_batch @@ -318,7 +386,14 @@ def _get_encoded_batch(rollout_batch, advantages): torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) advantages = torch.repeat_interleave(advantages, lengths) - + truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], + dtype=torch.bool, + device=self.device) + truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) + padding_length = labels.shape[1] - truncated_mask.shape[1] + if padding_length > 0: + padding = torch.zeros((1, padding_length), device=truncated_mask.device, dtype=truncated_mask.dtype) + truncated_mask = torch.cat([truncated_mask, padding], dim=1) # Pad advantages to match the original position_ids length original_length = position_ids.shape[1] if advantages.shape[0] < original_length: @@ -327,12 +402,9 @@ def _get_encoded_batch(rollout_batch, advantages): advantages = torch.cat([advantages, padding]) encoded_batch.update({ - 'completion_mask': - labels != -100, - 'truncated_mask': - torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device), - 'advantages': - advantages, + 'completion_mask': labels != -100, + 'truncated_mask': truncated_mask, + 'advantages': advantages, }) return encoded_batch @@ -694,16 +766,18 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size] lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] + if mpu.get_context_parallel_world_size() > 1: + # When using Context Parallel, each rank only processes a portion of the sequence + # So we need to divide the lengths by CP size + cp_size = mpu.get_context_parallel_world_size() + cu_seqlens_cp = packed_seq_params.cu_seqlens_q // cp_size + lengths_with_padding = cu_seqlens_cp[1:] - cu_seqlens_cp[:-1] + lengths = cu_seqlens_cp[1:micro_batch_size + 1] - cu_seqlens_cp[:micro_batch_size] + per_token_logps = self.get_logps( output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True) if self.args.overlong_filter and any(truncated_mask): - # TODO: non-padding-free - truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) - padding_length = completion_mask.shape[1] - truncated_mask.shape[1] - if padding_length > 0: - padding = torch.zeros((1, padding_length), device=truncated_mask.device, dtype=truncated_mask.dtype) - truncated_mask = torch.cat([truncated_mask, padding], dim=1) completion_mask = completion_mask & (~truncated_mask) if self.beta != 0.0: diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 0adb9d6930..1c44cd2773 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -74,7 +74,7 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int): for i in range(cu_seqlens.shape[0] - 1): slices = [slice(None)] * inputs.ndim slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1]) - val = inputs[slices] + val = inputs[tuple(slices)] view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) index = torch.tensor([cp_rank, (2 * cp_size - 
cp_rank - 1)], device='cpu', @@ -104,6 +104,11 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): keys.append('decoder_input') else: keys.append('input_ids') + if hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo': + keys.append('truncated_mask') + keys.append('advantages') + keys.append('completion_mask') + packed_seq_params = batch.get('packed_seq_params') if packed_seq_params is None: return mcore_get_batch_on_this_cp_rank(batch) From ff9b667fc721ce66c929fb8bde97fe8241777174 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 11 Oct 2025 16:39:14 +0800 Subject: [PATCH 33/83] fix pp+cp --- swift/megatron/trainers/grpo_trainer.py | 21 +++------------------ swift/megatron/trainers/rlhf_mixin.py | 2 ++ 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 4b05ae661d..de3ed6af7b 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -5,7 +5,6 @@ from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from functools import partial -from types import MethodType from typing import Any, Dict, List, Union import torch @@ -24,7 +23,7 @@ from ..argument import MegatronArguments, MegatronRLHFArguments from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, - offload_megatron_model_to_cpu, offload_megatron_optimizer, profiling_context, split_cp_inputs) + offload_megatron_model_to_cpu, offload_megatron_optimizer, profiling_context) try: from mbridge import AutoBridge @@ -326,9 +325,6 @@ def _replace_data_iterator(self, data_iterator, model): def _generate_and_score_completions(self, batch): # Get or create the rollout group (TP×PP×CP) - # We need to include CP in the rollout group because: - # 1. Data slicing: each rank (including CP ranks) should get different data - # 2. Gather operations: we need to gather data from all TP×PP×CP ranks rollout_group = self._get_rollout_group() # batch : same across DP groups @@ -336,17 +332,11 @@ def get_local_rollout_batch(batch): # repeat num_generations times global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] # get local rollout data - # For rollout, we need to distribute data across TP×PP×CP ranks - # Note: During rollout (vLLM inference), each GPU processes different data - # CP will only take effect during training (forward/backward) rollout_rank = torch.distributed.get_rank(group=rollout_group) rollout_group_size = torch.distributed.get_world_size(group=rollout_group) per_device_batch_size = self.per_device_generation_batch_size - assert rollout_group_size * per_device_batch_size == len(global_rollout_batch), ( - f'rollout_group_size ({rollout_group_size}) * per_device_batch_size ({per_device_batch_size}) ' - f'!= len(global_rollout_batch) ({len(global_rollout_batch)}). 
' - f'rollout_rank={rollout_rank}') + assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) rollout_batch = global_rollout_batch[data_slice] return rollout_batch @@ -366,8 +356,6 @@ def _get_encoded_batch(rollout_batch, advantages): encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] encoded_batch = to_device(template.data_collator(encoded_batch), self.device) labels = encoded_batch['labels'] - # TODO: logits_to_keep - # logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() assert self.template.padding_free position_ids = encoded_batch.get('text_position_ids') if position_ids is None: @@ -699,9 +687,6 @@ def _disable_maxlength_template_context(self, template: Template): def _maybe_replace_response_token(self, batch): # maybe replace the response token with the response token ids to avoid repetitive tokenize - # ignore when loss_scale is set - if self.template.loss_scale.name != 'last_round': - return batch for data in batch: if 'response_token_ids' in data and data['response_token_ids']: loss_mask = None @@ -777,7 +762,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): per_token_logps = self.get_logps( output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True) - if self.args.overlong_filter and any(truncated_mask): + if self.args.overlong_filter and truncated_mask.any(): completion_mask = completion_mask & (~truncated_mask) if self.beta != 0.0: diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 7c937ab308..55e4ae6b42 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -59,6 +59,8 @@ def _forward_step_helper(model, inputs): if mpu.is_pipeline_first_stage(): micro_batch_size = 1 # use qkv_format 'thd' seq_length = inputs['input_ids'].shape[1] + if 'position_ids' in inputs: + seq_length = inputs['position_ids'].shape[-1] if args.sequence_parallel: seq_length //= mpu.get_tensor_model_parallel_world_size() recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], From ba4bfbfe6f9a8eea1bd3d1e4cdd1ceb74d89562f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 11 Oct 2025 23:30:15 +0800 Subject: [PATCH 34/83] lora wip --- swift/megatron/trainers/grpo_trainer.py | 66 +++++++++++++++++++++++-- swift/megatron/tuners/lora.py | 34 +++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index de3ed6af7b..7dd9415d5b 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -167,12 +167,25 @@ def prepare_vllm(self): return engine def _move_model_to_vllm(self): - # TODO: LoRA, server + # TODO: server if self.bridge is None: self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) self._patch_mbridge(self.bridge) - per_tensor_params = self.bridge.export_weights(self.unwrapped_models) - self.engine.inner_model.load_weights(per_tensor_params) + + # Handle LoRA: merge adapters before exporting weights + is_lora_training = self.args.train_type == 'lora' + if is_lora_training: + logger.info('Detected LoRA training mode. 
Merging LoRA adapters before weight export...') + self._merge_lora_adapters() + + try: + per_tensor_params = self.bridge.export_weights(self.unwrapped_models) + self.engine.inner_model.load_weights(per_tensor_params) + finally: + # Unmerge adapters to restore training state + if is_lora_training: + logger.info('Unmerging LoRA adapters to restore training state...') + self._unmerge_lora_adapters() def _prepare_rewards(self): # TODO: reward model @@ -221,8 +234,27 @@ def _prepare_rewards(self): assert self.reward_funcs, 'reward_funcs is not set' + def _merge_lora_adapters(self): + """Merge LoRA adapters into base model weights for vLLM inference.""" + from ..tuners import LoraParallelLinear + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Merge all active adapters + module.merge() + + def _unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + from ..tuners import LoraParallelLinear + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Unmerge to restore separate LoRA weights for training + module.unmerge() + def _patch_mbridge(self, bridge): original_method = bridge._weight_to_hf_format + original_export = bridge.export_weights def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): # skip ViT weights @@ -232,7 +264,35 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): return [mcore_weights_name], [mcore_weights] return original_method(mcore_weights_name, mcore_weights) + def export_weights_patched(models): + """Patched export_weights that filters out LoRA parameters and cleans names.""" + for name, param in original_export(models): + # Skip LoRA-related parameters (lora_A, lora_B) + # These should not be exported as they are already merged into base weights + if 'lora_A.' in name or 'lora_B.' in name: + logger.debug(f'Skipping LoRA parameter during export: {name}') + continue + # Skip lora embedding parameters if any + if 'lora_embedding_A' in name or 'lora_embedding_B' in name: + logger.debug(f'Skipping LoRA embedding parameter during export: {name}') + continue + + # Clean LoRA-specific prefixes from parameter names + # LoRA wraps base layers, adding '.base_layer' to the parameter path + # We need to remove this so mbridge can recognize standard Megatron parameter names + if '.base_layer.' in name: + name = name.replace('.base_layer.', '.') + logger.debug(f'Cleaned LoRA base_layer from parameter name: {name}') + + # Handle modules_to_save if needed + if '.modules_to_save.default.' in name: + name = name.replace('.modules_to_save.default.', '.') + logger.debug(f'Cleaned modules_to_save from parameter name: {name}') + + yield name, param + bridge._weight_to_hf_format = _weight_to_hf_format_patched + bridge.export_weights = export_weights_patched def _get_rollout_group(self): """ diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index f9ad78ef50..1ef284a3d7 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -422,6 +422,40 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N if origin_device.type == 'cpu': self.to(device=origin_device) + def unmerge(self) -> None: + """ + Unmerge all merged adapter weights from the base weights. + + This method reverses the merge operation by subtracting the LoRA delta weights + from the base layer weights, restoring the original base weights. 
+ """ + if not self.merged: + # No adapters to unmerge + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + if origin_device.type == 'cpu': + self.to(device=get_current_device()) + + for active_adapter in self.merged_adapters: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + # Subtract the delta weight to unmerge + orig_weight.data -= delta_weight + + # Clear the merged adapters list + self.merged_adapters = [] + + if origin_device.type == 'cpu': + self.to(device=origin_device) + def dispatch_megatron( target: torch.nn.Module, From e22c7901681046385ac5f06ecbbe512dacf51e35 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 13 Oct 2025 21:12:32 +0800 Subject: [PATCH 35/83] arguments document --- .../Instruction/GRPO/AdvancedResearch/GSPO.md | 2 +- ...44\350\241\214\345\217\202\346\225\260.md" | 20 ++++++------ ...44\350\241\214\345\217\202\346\225\260.md" | 31 +++++++++++++++++- .../Instruction/Command-line-parameters.md | 22 +++++++------ .../Megatron-SWIFT/Command-line-parameters.md | 32 ++++++++++++++++++- 5 files changed, 85 insertions(+), 22 deletions(-) diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md index 1f21f2abfe..6dc03118e2 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -54,7 +54,7 @@ importance_weights = torch.exp(log_importance_weights) - `importance_sampling_level sequence` (GSPO) - `importance_sampling_level sequence_token` (GSPO-token) -其中 sequence_token 要求 ms-swift > 3.7 (源码安装) +其中 sequence_token 要求 ms-swift >= 3.8 论文其他超参 ```bash diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ba6256e2d0..bcc7325f87 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -505,6 +505,15 @@ reward模型参数将在PPO、GRPO中使用。 #### GRPO参数 - beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 +- epsilon: clip 系数,默认为0.2。 +- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 +- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 +- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 +- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 +- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 +- top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) +- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 、 `sequence` 和 `sequence_token`,默认为`token`。具体参考[GSPO文档](./GRPO/AdvancedResearch/GSPO.md) - per_device_train_batch_size: 每个设备训练批量大小,在GRPO中,指 completion 的批次大小。 - per_device_eval_batch_size: 每个设备评估批量大小,在GRPO中,指 completion 的批次大小。 - generation_batch_size: 采样completion批量大小,需要是 num_processes * per_device_train_batch_size 的倍数,默认等于 per_device_batch_size * gradient_accumulation_steps * 
num_processes @@ -542,22 +551,15 @@ reward模型参数将在PPO、GRPO中使用。 - completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。 `total`限制所有对话轮次的总输出长度不超过`max_completion_length`, `per_round`限制每一轮的输出长度。 - num_iterations: 每个批次代更新次数,默认为1。 -- epsilon: clip 系数,默认为0.2。 -- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 -- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 - sync_ref_model: 是否定期同步ref_model,默认为False。 - ref_model_mixup_alpha: 控制在更新过程中model和先前ref_model之间的混合。更新公式为 $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$。默认为0.6。 - ref_model_sync_steps:同步频率,默认为512。 - move_model_batches: 在模型向vLLM等快速推理框架移动参数时,将layers分为多少个batch. 默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个。注意:该参数仅对LoRA(PEFT)训练有意义。 - multi_turn_scheduler: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 -- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 -- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 -- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 -- top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) -- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +##### 奖励函数参数 +内置的奖励函数参考[文档](./GRPO/DeveloperGuide/奖励函数.md) cosine 奖励参数 - cosine_min_len_value_wrong:cosine 奖励函数参数,生成错误答案时,最小长度对应的奖励值。默认值为-0.5。 - cosine_max_len_value_wrong:生成错误答案时,最大长度对应的奖励值。默认值为0.0。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 9892d74035..bae33fa5fe 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -236,7 +236,7 @@ lora训练: **DPO参数**: -- ref_load: ref_model的加载路径。采用DPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 +- ref_load: ref_model的加载路径。采用DPO/GRPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 - ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。 - beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 - 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。 @@ -254,6 +254,35 @@ lora训练: - desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 - undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 +**GRPO参数** +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 +- epsilon: clip 系数,默认为0.2。 +- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 +- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 、 `sequence` 和 
`sequence_token`,默认为`token`。具体参考[GSPO文档](../Instruction/GRPO/AdvancedResearch/GSPO.md) +- batch size 相关参数(注意以下均为 completion-level) + - micro_batch_size: 每个device的批次大小,默认为1。 + - global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。对应每次更新权重的训练数据大小(mini_batch_size) + - generation_batch_size: 采样批量大小,需要是global_batch_size的倍数,默认等于global_batch_size + - steps_per_generation:每轮生成的优化步数,即采样批量大小相对global_batch_size的倍数,默认为1。 + - num_generations:每个prompt采样的数量,论文中的G值。采样批量大小需被num_generations 整除。默认为 8。 +- reward_funcs: GRPO算法奖励函数,可选项为`accuracy`、`format`、`cosine`、`repetition`和`soft_overlong`,见swift/plugin/orm.py。你也可以在plugin中自定义自己的奖励函数。默认为`[]`。 +- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。如果为 None,则所有奖励的权重都相等,为`1.0`。 +- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)。 +- vllm_mode 参数 + - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 + - vllm_max_model_len: vllm透传参数,默认为None。 + - vllm_enforce_eager: vllm透传参数,默认为False。 + - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 + - vllm_enable_prefix_caching: vllm透传参数,默认为True。 + - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放 + - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 + - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 + +内置奖励函数参数参考[文档](../Instruction/命令行参数.md#奖励函数参数) + ## 训练参数 Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用dataset、template等参数,也支持ms-swift中的特定模型参数**)。基本参数的内容可以参考[这里](../Instruction/命令行参数.md#基本参数)。此外还包括以下参数: diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 5776c9ccf1..fda80fbbd1 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -515,6 +515,15 @@ The meanings of the following parameters can be referenced [here](https://huggin #### GRPO Arguments - beta: KL regularization coefficient; default 0.04. Setting it to 0 disables the reference model. +- epsilon: epsilon value for clipping. Default is 0.2. +- epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon. +- delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). +- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. +- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False. +- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times. +- top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md). +- log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). +- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. 
In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. - per_device_train_batch_size: The training batch size per device. In GRPO, this refers to the batch size of completions during training. - per_device_eval_batch_size: The evaluation batch size per device. In GRPO, this refers to the batch size of completions during evaluation. - generation_batch_size: Batch size to use for generation. It defaults to the effective training batch size: per_device_train_batch_size * num_processes * gradient_accumulation_steps` @@ -556,23 +565,16 @@ The meanings of the following parameters can be referenced [here](https://huggin - top_p: Default is 0.9. - repetition_penalty: Repetition penalty term. Default is 1. - num_iterations: number of iterations per batch. Default is 1. -- epsilon: epsilon value for clipping. Default is 0.2. -- epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon. -- delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). + - sync_ref_model: Whether to synchronize the reference model. Default is False。 - ref_model_mixup_alpha: The Parameter controls the mix between the current policy and the previous reference policy during updates. The reference policy is updated according to the equation: $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$. Default is 0.6. - ref_model_sync_steps:The parameter determines how frequently the current policy is synchronized with the reference policy. Default is 512. - move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches. This parameter is only meaningful for LoRA (PEFT). - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py. - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit. -- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False. -- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times. -- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. -The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions). -- top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md). -- log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). 
-- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. +##### Reward function parameters +Refer to the [documentation](./GRPO/DeveloperGuide/reward_function.md) for built-in reward functions. cosine reward function arguments - cosine_min_len_value_wrong (default: -0.5): Reward value corresponding to the minimum length when the answer is incorrect. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 4deaf2fc98..defacd0735 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -250,7 +250,7 @@ LoRA Training: - use_rslora: Default is `False`. Whether to use `RS-LoRA`. **DPO Parameters** -- ref_load: The loading path for the reference model. This must be provided when using DPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. +- ref_load: The loading path for the reference model. This must be provided when using DPO/GRPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. - ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`. - beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. - 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default. @@ -268,6 +268,36 @@ LoRA Training: - desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. - undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. +**GRPO Parameters** +- ref_load: Same meaning as in DPO. +- ref_adapter_load: Same meaning as in DPO. +- beta: KL regularization coefficient, default is 0.04. When set to 0, the reference model is not loaded. +- epsilon: Clip coefficient, default is 0.2. +- epsilon_high: Upper clip coefficient, default is None. When set, forms a clipping range [epsilon, epsilon_high] together with epsilon. 
+- overlong_filter: Skips samples that are truncated due to excessive length and excludes them from loss computation. Default is False. +- importance_sampling_level: Controls the level at which importance sampling ratios are computed. Options are `token`, `sequence`, and `sequence_token`. Default is `token`. See [GSPO Documentation](../Instruction/GRPO/AdvancedResearch/GSPO.md) for details. +- Batch Size Related Parameters (Note: all are completion-level) + - micro_batch_size: Batch size per device, default is 1. + - global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallelism size * gradient accumulation steps`. Default is 16. Corresponds to the mini_batch_size (number of training samples per weight update). + - generation_batch_size: Sampling batch size, must be a multiple of global_batch_size. Default equals global_batch_size. + - steps_per_generation: Number of optimization steps per generation round, i.e., the ratio of generation_batch_size to global_batch_size. Default is 1. + - num_generations: Number of samples generated per prompt (the "G" value in the paper). generation_batch_size must be divisible by num_generations. Default is 8. +- reward_funcs: Reward functions used in GRPO algorithm. Options include `accuracy`, `format`, `cosine`, `repetition`, and `soft_overlong`, defined in swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`. +- reward_weights: Weights assigned to each reward function. Must match the total number of reward functions and reward models. If None, all rewards are equally weighted with `1.0`. +- loss_type: Type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo']. Default is 'grpo'. See this [PR](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348) for details. + +- vLLM Parameters + - vllm_gpu_memory_utilization: Pass-through parameter to vLLM, default is 0.9. + - vllm_max_model_len: Pass-through parameter to vLLM, default is None. + - vllm_enforce_eager: Pass-through parameter to vLLM, default is False. + - vllm_limit_mm_per_prompt: Pass-through parameter to vLLM, default is None. + - vllm_enable_prefix_caching: Pass-through parameter to vLLM, default is True. + - sleep_level: Release vLLM GPU memory during training. Options are [0, 1], default is 0 (no release). + - offload_optimizer: Whether to offload optimizer states during vLLM inference. Default is False. + - offload_model: Whether to offload model weights during vLLM inference. Default is False. + +For built-in reward function parameters, refer to the [documentation](../Instruction/GRPO/DeveloperGuide/reward_function.md). + ## Training Parameters Megatron training parameters are inherited from Megatron parameters and basic parameters (**sharing dataset, template, etc. with ms-swift, and also supporting model-specific parameters from ms-swift**). For details on basic parameters, please refer to [here](../Instruction/Command-line-parameters.md#base-arguments). 
Additionally, the following parameters are included: From b3de2621d37b337b15a2863a4bf0ac05bc038735 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 14 Oct 2025 16:13:15 +0800 Subject: [PATCH 36/83] wip lora&cp --- swift/megatron/trainers/grpo_trainer.py | 53 +++++++++++++++++++++++-- swift/megatron/trainers/utils.py | 36 +++++++++++++++++ swift/megatron/tuners/lora.py | 34 ++++++++++++++++ 3 files changed, 119 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index de3ed6af7b..f9685556db 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -23,7 +23,8 @@ from ..argument import MegatronArguments, MegatronRLHFArguments from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, - offload_megatron_model_to_cpu, offload_megatron_optimizer, profiling_context) + offload_megatron_model_to_cpu, offload_megatron_optimizer, patch_model_for_lora_export, + profiling_context) try: from mbridge import AutoBridge @@ -167,12 +168,32 @@ def prepare_vllm(self): return engine def _move_model_to_vllm(self): - # TODO: LoRA, server + # TODO: server if self.bridge is None: self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) self._patch_mbridge(self.bridge) - per_tensor_params = self.bridge.export_weights(self.unwrapped_models) - self.engine.inner_model.load_weights(per_tensor_params) + + # Handle LoRA: merge adapters before exporting weights + is_lora_training = self.args.train_type == 'lora' + restore_funcs = [] + + try: + if is_lora_training: + self._merge_lora_adapters() + for model in self.unwrapped_models: + restore_func = patch_model_for_lora_export(model) + restore_funcs.append(restore_func) + + per_tensor_params = self.bridge.export_weights(self.unwrapped_models) + self.engine.inner_model.load_weights(per_tensor_params) + finally: + for restore_func in restore_funcs: + restore_func() + + # Unmerge adapters to restore training state + if is_lora_training: + logger.info('Unmerging LoRA adapters to restore training state...') + self._unmerge_lora_adapters() def _prepare_rewards(self): # TODO: reward model @@ -221,6 +242,24 @@ def _prepare_rewards(self): assert self.reward_funcs, 'reward_funcs is not set' + def _merge_lora_adapters(self): + """Merge LoRA adapters into base model weights for vLLM inference.""" + from ..tuners import LoraParallelLinear + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Merge all active adapters + module.merge() + + def _unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + from ..tuners import LoraParallelLinear + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Unmerge to restore separate LoRA weights for training + module.unmerge() + def _patch_mbridge(self, bridge): original_method = bridge._weight_to_hf_format @@ -230,6 +269,12 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): if 'visual.visual' in mcore_weights_name: mcore_weights_name = mcore_weights_name.replace('visual.visual', 'visual') return [mcore_weights_name], [mcore_weights] + + if '.base_layer.' in mcore_weights_name: + mcore_weights_name = mcore_weights_name.replace('.base_layer.', '.') + + if '.modules_to_save.default.' 
in mcore_weights_name: + mcore_weights_name = mcore_weights_name.replace('.modules_to_save.default.', '.') return original_method(mcore_weights_name, mcore_weights) bridge._weight_to_hf_format = _weight_to_hf_format_patched diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 1c44cd2773..77cce7ec7f 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -352,3 +352,39 @@ def log_gpu_memory(prefix: str = ''): logger.info(f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') + + +def should_filter_lora_parameter(name: str) -> bool: + if 'lora_' in name: + return True + + if 'original_module' in name: + return True + return False + + +def patch_model_for_lora_export(model): + original_named_parameters = model.named_parameters + original_state_dict = model.state_dict + + def filtered_named_parameters(*args, **kwargs): + for name, param in original_named_parameters(*args, **kwargs): + if not should_filter_lora_parameter(name): + yield name, param + + def filtered_state_dict(*args, **kwargs): + state_dict = original_state_dict(*args, **kwargs) + filtered = {} + for name, param in state_dict.items(): + if not should_filter_lora_parameter(name): + filtered[name] = param + return filtered + + model.named_parameters = filtered_named_parameters + model.state_dict = filtered_state_dict + + def restore(): + model.named_parameters = original_named_parameters + model.state_dict = original_state_dict + + return restore diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index f9ad78ef50..1ef284a3d7 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -422,6 +422,40 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N if origin_device.type == 'cpu': self.to(device=origin_device) + def unmerge(self) -> None: + """ + Unmerge all merged adapter weights from the base weights. + + This method reverses the merge operation by subtracting the LoRA delta weights + from the base layer weights, restoring the original base weights. 
+ """ + if not self.merged: + # No adapters to unmerge + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + if origin_device.type == 'cpu': + self.to(device=get_current_device()) + + for active_adapter in self.merged_adapters: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + # Subtract the delta weight to unmerge + orig_weight.data -= delta_weight + + # Clear the merged adapters list + self.merged_adapters = [] + + if origin_device.type == 'cpu': + self.to(device=origin_device) + def dispatch_megatron( target: torch.nn.Module, From fe3270f33d62617318c71be758d69df8ece236f6 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 14 Oct 2025 16:42:05 +0800 Subject: [PATCH 37/83] remove unused patch --- swift/megatron/trainers/grpo_trainer.py | 29 ------------------------- 1 file changed, 29 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index a52050475d..f9685556db 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -262,7 +262,6 @@ def _unmerge_lora_adapters(self): def _patch_mbridge(self, bridge): original_method = bridge._weight_to_hf_format - original_export = bridge.export_weights def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): # skip ViT weights @@ -278,35 +277,7 @@ def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): mcore_weights_name = mcore_weights_name.replace('.modules_to_save.default.', '.') return original_method(mcore_weights_name, mcore_weights) - def export_weights_patched(models): - """Patched export_weights that filters out LoRA parameters and cleans names.""" - for name, param in original_export(models): - # Skip LoRA-related parameters (lora_A, lora_B) - # These should not be exported as they are already merged into base weights - if 'lora_A.' in name or 'lora_B.' in name: - logger.debug(f'Skipping LoRA parameter during export: {name}') - continue - # Skip lora embedding parameters if any - if 'lora_embedding_A' in name or 'lora_embedding_B' in name: - logger.debug(f'Skipping LoRA embedding parameter during export: {name}') - continue - - # Clean LoRA-specific prefixes from parameter names - # LoRA wraps base layers, adding '.base_layer' to the parameter path - # We need to remove this so mbridge can recognize standard Megatron parameter names - if '.base_layer.' in name: - name = name.replace('.base_layer.', '.') - logger.debug(f'Cleaned LoRA base_layer from parameter name: {name}') - - # Handle modules_to_save if needed - if '.modules_to_save.default.' 
in name: - name = name.replace('.modules_to_save.default.', '.') - logger.debug(f'Cleaned modules_to_save from parameter name: {name}') - - yield name, param - bridge._weight_to_hf_format = _weight_to_hf_format_patched - bridge.export_weights = export_weights_patched def _get_rollout_group(self): """ From ca9c9bcf56298c3f241d0b14b181d1a15b0acea8 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 29 Oct 2025 15:47:05 +0800 Subject: [PATCH 38/83] wip server --- swift/megatron/trainers/grpo_trainer.py | 35 ++++++++++++++----------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index f9685556db..0a2de750fe 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +from accelerate.utils import broadcast_object_list from megatron.core import mpu from megatron.training import get_args, training from trl.trainer.grpo_trainer import nanstd @@ -41,7 +42,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template): self.args = args self.hf_model_dir = args.model_info.model_dir self.processing_class = self.template.processor - # TODO: multi turn scheduler(colocate multi turn) self._prepare_template_data_collator() self._init_grpo_params() self._prepare_rewards() @@ -102,6 +102,7 @@ def _init_grpo_params(self): return_details=True) self._step = 0 + self._last_loaded_step = -1 self._rollout_group = None # Will be lazily initialized def _prepare_rollout_engine(self): @@ -118,12 +119,13 @@ def _prepare_rollout_engine(self): self.enable_server_multi_turn = False # TODO # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs self.dynamic_num_samples = False - if self.use_vllm: - if not is_vllm_available(): - raise ImportError('vLLM is not available and `use_vllm` is set to True. ' - 'Please install vLLM with `pip install vllm -U` to use it.') - assert self.vllm_mode == 'colocate' # TODO: server mode - + assert self.use_vllm + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. ' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.vllm_mode == 'server': + self.vllm_client = self.args.vllm_client + elif self.vllm_mode == 'colocate': if not self.world_size % self.vllm_tensor_parallel_size == 0: raise ValueError(f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' f'({self.world_size}) evenly.') @@ -136,6 +138,8 @@ def _prepare_rollout_engine(self): if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) log_gpu_memory('after sleep vLLM engine') + else: + raise ValueError(f'Invalid vllm_mode: {self.vllm_mode}') def prepare_vllm(self): from swift.llm.infer.infer_engine import GRPOVllmEngine @@ -468,14 +472,10 @@ def _generate_completions(self, batch): Returns: batch: The input batch with rollout completion results merged in. - - Note: - Currently only supports colocate mode. Server mode support is planned for future implementation. 
""" # TODO: server mode - assert self.vllm_mode == 'colocate' # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) - if self.engine.inner_model_executor.is_sleeping: + if self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping: wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters # Load weights only (faster and reduces memory peak) kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} @@ -483,11 +483,13 @@ def _generate_completions(self, batch): log_gpu_memory(f'after wake up vLLM engine with {kwargs}') # Step 2: Load model weights - self._move_model_to_vllm() + if self._step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self._step context = self.offload_context if self.enable_offload else nullcontext with context(): - if (self.engine.inner_model_executor.is_sleeping + if (self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): self.engine.engine.wake_up(tags=['kv_cache']) log_gpu_memory('after wake up vLLM engine with kv_cache') @@ -497,8 +499,9 @@ def _generate_completions(self, batch): outputs: List[RolloutOutput] = self._rollout(batch) # Step4: Sleep to release memory - if self.args.sleep_level > 0: - self.engine.engine.sleep(self.args.sleep_level) + if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: + self.engine.engine.reset_prefix_cache() + self.engine.engine.sleep(level=self.args.sleep_level) log_gpu_memory('after sleep vLLM engine') batch = self.postprocess_rollout_data(batch, outputs) From f258202c3cd612d703380e80fe83ec633ab76cdd Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 29 Oct 2025 16:46:58 +0800 Subject: [PATCH 39/83] wip --- swift/megatron/trainers/grpo_trainer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 0a2de750fe..243edff01a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -199,6 +199,11 @@ def _move_model_to_vllm(self): logger.info('Unmerging LoRA adapters to restore training state...') self._unmerge_lora_adapters() + if self.vllm_mode == 'server' and self.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == 'colocate': + self.engine.engine.reset_prefix_cache() + def _prepare_rewards(self): # TODO: reward model args = self.args @@ -532,7 +537,10 @@ def preprocess_rollout_data(self, batch): def _rollout(self, batch) -> List[RolloutOutput]: request_config = self._get_request_config() # TODO: server mode - rollout_outputs = self._colocate_rollout(batch, request_config) + if self.vllm_mode == 'server': + pass + elif self.vllm_mode == 'colocate': + rollout_outputs = self._colocate_rollout(batch, request_config) return rollout_outputs def postprocess_rollout_data(self, batch, outputs): From 0a38c0c7fa1375c726f24d10f9532aa61391d98c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 30 Oct 2025 14:14:13 +0800 Subject: [PATCH 40/83] server rollout wip --- swift/megatron/trainers/grpo_trainer.py | 132 ++++++++++++++++++++- swift/trainers/rlhf_trainer/vllm_client.py | 5 + 2 files changed, 133 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 243edff01a..f2f0882602 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1,4 +1,5 @@ # 
Copyright (c) Alibaba, Inc. and its affiliates. +import base64 import gc import inspect from collections import defaultdict @@ -10,14 +11,15 @@ import torch import torch.nn as nn from accelerate.utils import broadcast_object_list +from dacite import from_dict from megatron.core import mpu from megatron.training import get_args, training from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps -from swift.llm import RequestConfig, RowPreprocessor, Template, to_device +from swift.llm import RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device from swift.llm.infer.protocol import RolloutOutput -from swift.plugin import orms +from swift.plugin import MultiTurnScheduler, multi_turns, orms from swift.trainers.rlhf_trainer.grpo_trainer import DataType from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response @@ -45,6 +47,7 @@ def __init__(self, args: MegatronRLHFArguments, template: Template): self._prepare_template_data_collator() self._init_grpo_params() self._prepare_rewards() + self._prepare_scheduler() # TODO self._prepare_rollout_engine() # debug: use mbridge to convert mcore to hf self.bridge = None @@ -118,7 +121,6 @@ def _prepare_rollout_engine(self): self.use_gym_env = False self.enable_server_multi_turn = False # TODO # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs - self.dynamic_num_samples = False assert self.use_vllm if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' @@ -189,6 +191,7 @@ def _move_model_to_vllm(self): restore_funcs.append(restore_func) per_tensor_params = self.bridge.export_weights(self.unwrapped_models) + # TODO: server mode self.engine.inner_model.load_weights(per_tensor_params) finally: for restore_func in restore_funcs: @@ -251,6 +254,23 @@ def _prepare_rewards(self): assert self.reward_funcs, 'reward_funcs is not set' + def _prepare_scheduler(self): + """Prepare multi-turn scheduler""" + args = self.args + + self.multi_turn_scheduler = None + if not hasattr(args, 'multi_turn_scheduler'): + return + + if args.multi_turn_scheduler: + if isinstance(args.multi_turn_scheduler, str): + assert args.multi_turn_scheduler in multi_turns + multi_turn_scheduler = multi_turns[args.multi_turn_scheduler](max_turns=args.max_turns) + self.multi_turn_scheduler: MultiTurnScheduler = multi_turn_scheduler + else: + assert isinstance(args.multi_turn_scheduler, MultiTurnScheduler) + self.multi_turn_scheduler: MultiTurnScheduler = args.multi_turn_scheduler + def _merge_lora_adapters(self): """Merge LoRA adapters into base model weights for vLLM inference.""" from ..tuners import LoraParallelLinear @@ -538,7 +558,7 @@ def _rollout(self, batch) -> List[RolloutOutput]: request_config = self._get_request_config() # TODO: server mode if self.vllm_mode == 'server': - pass + self._server_rollout(batch, request_config) elif self.vllm_mode == 'colocate': rollout_outputs = self._colocate_rollout(batch, request_config) return rollout_outputs @@ -615,8 +635,59 @@ def _get_request_config(self) -> RequestConfig: return request_config + def _server_rollout(self, + inputs: DataType, + request_config: RequestConfig, + is_global_inputs: bool = False) -> List[RolloutOutput]: + # TODO: async generate + infer_requests = self.inputs2requests(inputs) + + if is_global_inputs: + per_device_size = len(infer_requests) // self.world_size 
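+            # The caller already holds the full global batch here, so the cross-rank
+            # gather below is skipped; only the main process keeps the returned outputs.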
+ all_requests = infer_requests + all_requests_lengths = [per_device_size] + [0] * (self.world_size - 1) + else: + all_requests = gather_object(infer_requests) + all_requests_lengths = gather_object([len(infer_requests)]) + + if not any(requests for requests in all_requests): + return [] + + if self.is_main_process: + all_outputs: List[RolloutOutput] = self.vllm_client.infer( + infer_requests=all_requests, request_config=request_config) + assert len(all_outputs) == len(all_requests) # TODO: dynamic num of samples + else: + all_outputs = [None] * len(all_requests) + + if not is_global_inputs: + all_outputs = broadcast_object_list(all_outputs, from_process=0) + start_idx = sum(all_requests_lengths[:self.process_index]) + end_idx = start_idx + all_requests_lengths[self.process_index] + outputs = all_outputs[start_idx:end_idx] + else: + outputs = all_outputs if self.is_main_process else [] + return outputs + def _colocate_rollout(self, batch, request_config: RequestConfig): + if self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + local_input_length = len(batch) + all_input_lengths = [None] * self.vllm_tensor_parallel_size + torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.vllm_tp_group) + + start_idx = sum(all_input_lengths[:local_rank_in_group]) + end_idx = start_idx + all_input_lengths[local_rank_in_group] + + gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) + batch = [p for sublist in gathered_batch for p in sublist] + outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) + + if self.vllm_tensor_parallel_size > 1: + outputs = outputs[start_idx:end_idx] + return outputs def _score_completions(self, inputs: DataType) -> torch.Tensor: @@ -950,3 +1021,56 @@ def offload_context(self): if getattr(self, 'optimizer', None) and self.args.offload_optimizer: load_megatron_optimizer(self.optimizer) log_gpu_memory('after load optimizer to gpu') + + def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]: + """Convert raw input data into RolloutInferRequest objects""" + + def _process_image_data(image_data: Union[dict, str]) -> str: + if isinstance(image_data, dict): + if image_data.get('bytes'): + return base64.b64encode(image_data['bytes']).decode('utf-8') + if image_data.get('path'): + return image_data['path'] + return image_data + + if not inputs: + return [] + args = self.args + + REQUEST_METADATA_FIELDS = ['messages', 'images', 'audios', 'videos', 'tools', 'objects', 'uuid'] + requests_dicts = [] + + for data in inputs: + request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data and data[key] is not None} + if 'uuid' not in request_data: + request_data['uuid'] = data['request_id'] + if hasattr(args, 'vllm_server_pass_dataset') and args.vllm_server_pass_dataset: + extra_fields = { + k: v + for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and data[k] is not None + } + if extra_fields: + request_data['data_dict'] = extra_fields + elif self.multi_turn_scheduler: + base_data_dict = {} + if 'data_dict' in data: + if isinstance(data['data_dict'], dict): + base_data_dict = data['data_dict'] + else: + raise ValueError('data_dict exists but is not a dictionary') + extra_data = { + k: v + for k, v in data.items() + if k not in REQUEST_METADATA_FIELDS and k != 'data_dict' and data[k] is not None + } + 
final_data_dict = {**extra_data, **base_data_dict} + request_data['data_dict'] = final_data_dict if final_data_dict else {} + + requests_dicts.append(request_data) + + for request in requests_dicts: + if 'images' in request and request['images']: + request['images'] = ([_process_image_data(img) for img in request['images']] if isinstance( + request['images'], list) else _process_image_data(request['images'])) + + return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts] diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 8440a220bb..a5adab3d4f 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -133,9 +133,14 @@ def infer( results = [None] * self.num_servers errors = [None] * self.num_servers + if isinstance(request_config, RequestConfig): + request_config = asdict(request_config) def process_chunk(i, chunk): try: + if len(chunk) > 0 and isinstance(chunk[0], RolloutInferRequest): + chunk = [asdict(req) for req in chunk] + response = self.sessions[i].post( f'{self.base_urls[i]}/infer/', json={ From 5f2f349ade67db8bac58d2cbe74514aa5114b152 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 4 Nov 2025 18:15:55 +0800 Subject: [PATCH 41/83] move vllm client init out of args --- swift/megatron/argument/megatron_args.py | 25 ++++++------------------ swift/megatron/train/rlhf.py | 25 +++++++++++++++++++++--- swift/megatron/trainers/grpo_trainer.py | 11 ++++++++--- 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 17e30901a2..0fc2965c64 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -67,7 +67,6 @@ class RLHFMegatronArgumentsMixin: vllm_server_host: Optional[List[str]] = None vllm_server_port: List[int] = field(default_factory=lambda: [8000]) vllm_server_timeout: float = 240.0 - vllm_client: Optional[object] = field(init=False, default=None) # ─────────────────────────── Reward ─────────────────────────── reward_funcs: List[str] = field(default_factory=list) @@ -151,20 +150,6 @@ def __post_init__(self): def _init_grpo(self): - def _init_external_vllm(): - if self.rlhf_type != 'grpo' or (self.vllm_server_host is None and self.vllm_server_base_url is None): - return - from swift.trainers.rlhf_trainer.vllm_client import VLLMClient - if is_master(): - logger.info('Start connecting to vLLM server') - self.vllm_client = VLLMClient( - base_urls=self.vllm_server_base_url, - hosts=self.vllm_server_host, - server_ports=self.vllm_server_port, - connection_timeout=self.vllm_server_timeout) - self.vllm_client.init_communicator(device=get_current_device()) - logger.info('Connected to vLLM server') - def _check_not_supported(): pass @@ -184,12 +169,14 @@ def _check_batch_params(): raise ValueError( "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") world_size = torch.distributed.get_world_size() - assert self.generation_batch_size % world_size == 0, \ - f'generation_batch_size ({self.generation_batch_size}) ' \ + num_rollout_prompt = self.generation_batch_size // self.num_generations + assert num_rollout_prompt % world_size == 0, ( + f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size ' + f'({self.generation_batch_size}) // num_generations ({self.num_generations}) ' f'must be divisible by the world size ({world_size})' - 
self.per_device_generation_batch_size = self.generation_batch_size // world_size + f'please adjust generation_batch_size/steps_per_generation/num_generations to make it divisible') + self.per_device_generation_batch_size = num_rollout_prompt // world_size - _init_external_vllm() _check_not_supported() _check_batch_params() # default loss_type if no loss_type is provided diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 4d864b7595..21cd9ceb92 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -3,7 +3,7 @@ from swift.llm.train.kto import prepare_kto_dataset from swift.trainers.rlhf_trainer.utils import identity_data_collator -from swift.utils import get_logger +from swift.utils import get_current_device, get_logger, is_master from ..argument import MegatronRLHFArguments from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer, MegatronKTOTrainer, MegatronRewardTrainer from .sft import MegatronSft @@ -26,8 +26,10 @@ def prepare_trainer(self): trainer_cls = trainer_mapping.get(args.rlhf_type) if trainer_cls is None: raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.') - - return trainer_cls(args, self.template) + kwargs = {} + if args.rlhf_type == 'grpo': + kwargs['vllm_client'] = self._prepare_vllm_client() + return trainer_cls(args, self.template, **kwargs) def _prepare_template(self) -> None: super()._prepare_template() @@ -46,6 +48,23 @@ def _get_dataset(self): train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset) return train_dataset, val_dataset + def _prepare_vllm_client(self): + if self.args.rlhf_type != 'grpo' or (self.args.vllm_mode != 'server'): + return + from swift.trainers.rlhf_trainer.vllm_client import VLLMClient + vllm_client = None + if is_master(): + logger.info('Start connecting to vLLM server') + vllm_client = VLLMClient( + base_urls=self.args.vllm_server_base_url, + hosts=self.args.vllm_server_host, + server_ports=self.args.vllm_server_port, + connection_timeout=self.args.vllm_server_timeout) + vllm_client.close_communicator() + vllm_client.init_communicator(device=get_current_device()) + logger.info('Connected to vLLM server') + return vllm_client + def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index f2f0882602..674c5740da 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -39,7 +39,8 @@ class MegatronGRPOTrainer(MegatronRLHFTrainer): - def __init__(self, args: MegatronRLHFArguments, template: Template): + def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): + self.vllm_client = kwargs.pop('vllm_client') super().__init__(args, template) self.args = args self.hf_model_dir = args.model_info.model_dir @@ -126,7 +127,7 @@ def _prepare_rollout_engine(self): raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' 'Please install vLLM with `pip install vllm -U` to use it.') if self.vllm_mode == 'server': - self.vllm_client = self.args.vllm_client + pass elif self.vllm_mode == 'colocate': if not self.world_size % self.vllm_tensor_parallel_size == 0: raise ValueError(f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' @@ -175,6 +176,9 @@ def prepare_vllm(self): def _move_model_to_vllm(self): # TODO: server + if self.vllm_mode == 'server': + return # TODO + if self.bridge is None: self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) self._patch_mbridge(self.bridge) @@ -368,7 +372,8 @@ def _replace_data_iterator(self, data_iterator, model): if self._step % self.steps_per_generation == 0: # each rollout DP group will generate generation_batch_size / world_size completions - completions_to_rollout = self.generation_batch_size // mpu.get_data_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + completions_to_rollout = self.generation_batch_size // dp_size # completions will be repeated num_generations times after # so we need to divide num_iters_per_step by num_generations to get prompt batch size prompts_to_rollout = completions_to_rollout // self.num_generations From 416feb2f078d6c58f44081905aa4254427b57dff Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 4 Nov 2025 19:10:32 +0800 Subject: [PATCH 42/83] server mode --- swift/megatron/argument/megatron_args.py | 2 +- swift/megatron/trainers/grpo_trainer.py | 35 +++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 0fc2965c64..1487f90956 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -175,7 +175,7 @@ def _check_batch_params(): f'({self.generation_batch_size}) // num_generations ({self.num_generations}) ' f'must be divisible by the world size ({world_size})' f'please adjust generation_batch_size/steps_per_generation/num_generations to make it divisible') - self.per_device_generation_batch_size = num_rollout_prompt // world_size + self.per_device_generation_batch_size = self.generation_batch_size // world_size _check_not_supported() _check_batch_params() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 674c5740da..eef3000fa6 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -2,12 +2,14 @@ import base64 import gc import inspect +import uuid from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from functools import partial from typing import Any, Dict, List, Union +import json import torch import torch.nn as nn from accelerate.utils import broadcast_object_list @@ -504,6 +506,8 @@ def _generate_completions(self, batch): batch: The input batch with rollout completion results merged in. 
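+
+        Note: colocate mode uses the in-process vLLM engine; in server mode the gathered
+        requests are forwarded to the external vLLM server via ``self.vllm_client``
+        (see ``_server_rollout``).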
""" # TODO: server mode + # add prompt ids and system prompts + batch = self._preprocess_inputs(batch) # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) if self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping: wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters @@ -563,7 +567,7 @@ def _rollout(self, batch) -> List[RolloutOutput]: request_config = self._get_request_config() # TODO: server mode if self.vllm_mode == 'server': - self._server_rollout(batch, request_config) + rollout_outputs = self._server_rollout(batch, request_config) elif self.vllm_mode == 'colocate': rollout_outputs = self._colocate_rollout(batch, request_config) return rollout_outputs @@ -1079,3 +1083,32 @@ def _process_image_data(image_data: Union[dict, str]) -> str: request['images'], list) else _process_image_data(request['images'])) return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts] + + def _preprocess_inputs(self, inputs: DataType) -> DataType: + """Preprocess inputs before inference""" + processed_inputs = self._add_prompt_id_to_inputs(inputs) + for input_item in processed_inputs: + remove_response(input_item['messages']) + return processed_inputs + + def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: + """Add unique prompt_id and request_id to each input""" + if not inputs: + return inputs + + all_messages = gather_object([inp['messages'] for inp in inputs]) + messages_to_prompt_id = {} + prompt_id_counter = 0 + + for messages in all_messages: + key = json.dumps(messages) + if key not in messages_to_prompt_id: + messages_to_prompt_id[key] = f'prompt_{prompt_id_counter}' + prompt_id_counter += 1 + + for input_item in inputs: + messages = input_item.get('messages') + input_item['prompt_id'] = messages_to_prompt_id[json.dumps(messages)] + input_item['request_id'] = f'chatcmpl-{str(uuid.uuid4().hex)}' + + return inputs From b93c031cd43ccb22a95214070ab567544477abe7 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 4 Nov 2025 19:59:57 +0800 Subject: [PATCH 43/83] remove old func --- swift/megatron/trainers/rlhf_mixin.py | 31 --------------------------- 1 file changed, 31 deletions(-) diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index adf0801106..d61c550cea 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -50,37 +50,6 @@ def null_ref_context(self): for m in self.peft_models: m.set_adapter('default') - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if 'position_ids' in inputs: - seq_length = inputs['position_ids'].shape[-1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = 
model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, per_token=False): args = get_args() per_token_logps = -output_tensor From 2f5d7b551387d34feee5a4a379537615e31748ac Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 4 Nov 2025 21:19:19 +0800 Subject: [PATCH 44/83] mcore bridge --- swift/megatron/trainers/grpo_trainer.py | 50 ++++++------------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index b09d30407a..de3788519c 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -24,7 +24,7 @@ from swift.plugin import MultiTurnScheduler, multi_turns, orms from swift.trainers.rlhf_trainer.grpo_trainer import DataType from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids -from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response +from swift.utils import get_current_device, get_logger, is_master, is_vllm_available, remove_response from ..argument import MegatronArguments, MegatronRLHFArguments from ..utils import forward_step_helper from .rlhf_mixin import MegatronRLHFTrainer @@ -32,11 +32,6 @@ offload_megatron_model_to_cpu, offload_megatron_optimizer, patch_model_for_lora_export, profiling_context) -try: - from mbridge import AutoBridge -except ImportError: - pass - logger = get_logger() @@ -53,8 +48,7 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): self._prepare_rewards() self._prepare_scheduler() # TODO self._prepare_rollout_engine() - # debug: use mbridge to convert mcore to hf - self.bridge = None + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} def _prepare_template_data_collator(self): @@ -77,7 +71,7 @@ def _init_grpo_params(self): # distributed params self.world_size = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() - self.is_main_process = self.process_index == 0 + self.is_main_process = is_master() self.device = get_current_device() # algorithm params self.num_generations = args.num_generations # G in the GRPO paper @@ -178,14 +172,6 @@ def prepare_vllm(self): return engine def _move_model_to_vllm(self): - # TODO: server - if self.vllm_mode == 'server': - return # TODO - - if self.bridge is None: - self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) - self._patch_mbridge(self.bridge) - # Handle LoRA: merge adapters before exporting weights is_lora_training = self.args.train_type == 'lora' restore_funcs = [] @@ -197,9 +183,14 @@ def _move_model_to_vllm(self): restore_func = patch_model_for_lora_export(model) restore_funcs.append(restore_func) - per_tensor_params = self.bridge.export_weights(self.unwrapped_models) - # TODO: server mode - self.engine.inner_model.load_weights(per_tensor_params) + per_tensor_params = dict(self.bridge.export_weights(self.unwrapped_models)) + + if self.vllm_mode == 'server': + if self.is_main_process: + for name, param in per_tensor_params.items(): + self.vllm_client.update_named_param(name, param) + elif self.vllm_mode == 'colocate': + self.engine.inner_model.load_weights(per_tensor_params) finally: for restore_func in restore_funcs: restore_func() @@ -296,25 +287,6 @@ def _unmerge_lora_adapters(self): # Unmerge to restore separate LoRA weights for training module.unmerge() - 
def _patch_mbridge(self, bridge): - original_method = bridge._weight_to_hf_format - - def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): - # skip ViT weights - if 'visual' in mcore_weights_name: - if 'visual.visual' in mcore_weights_name: - mcore_weights_name = mcore_weights_name.replace('visual.visual', 'visual') - return [mcore_weights_name], [mcore_weights] - - if '.base_layer.' in mcore_weights_name: - mcore_weights_name = mcore_weights_name.replace('.base_layer.', '.') - - if '.modules_to_save.default.' in mcore_weights_name: - mcore_weights_name = mcore_weights_name.replace('.modules_to_save.default.', '.') - return original_method(mcore_weights_name, mcore_weights) - - bridge._weight_to_hf_format = _weight_to_hf_format_patched - def _get_rollout_group(self): """ Get or create the rollout process group (TP×PP×CP). From b3b37ced6a221aff8bd30b5c7db8aadcdf808a7e Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 5 Nov 2025 15:43:00 +0800 Subject: [PATCH 45/83] merge main & flatten weight sync --- docs/source/Megatron-SWIFT/Mcore-Bridge.md | 100 ++++++++++++++++- ...44\350\241\214\345\217\202\346\225\260.md" | 3 +- .../Megatron-SWIFT/Command-line-parameters.md | 3 +- docs/source_en/Megatron-SWIFT/Mcore-Bridge.md | 103 +++++++++++++++++- swift/llm/template/template/glm.py | 2 +- swift/llm/template/template/qwen.py | 4 +- swift/megatron/__init__.py | 8 +- swift/megatron/argument/megatron_args.py | 2 +- swift/megatron/export/export.py | 6 +- swift/megatron/model/gpt_bridge.py | 74 +++++++++---- swift/megatron/trainers/base.py | 2 +- swift/megatron/trainers/grpo_trainer.py | 43 ++++++-- swift/megatron/trainers/utils.py | 2 +- swift/megatron/utils/utils.py | 2 +- 14 files changed, 307 insertions(+), 47 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Mcore-Bridge.md b/docs/source/Megatron-SWIFT/Mcore-Bridge.md index abed20034f..f8d3125a70 100644 --- a/docs/source/Megatron-SWIFT/Mcore-Bridge.md +++ b/docs/source/Megatron-SWIFT/Mcore-Bridge.md @@ -130,8 +130,8 @@ swift infer \ Mcore-Bridge除了支持全参数的导入导出,还支持单独对LoRA增量模型进行导入导出。 以下为纯文本模型Qwen3-Moe模型使用LoRA自我认知训练的例子: -- 若你希望导出merge后的权重,而不是LoRA增量权重,请设置`--merge_lora true`。 -- 注意:由于transformers和Megatron模型结构并不一定一致(例如transformers的Qwen3-VL-Moe的专家部分并不是Linear实现,而是Parameters),因此部分模型无法转换(若Qwen3-VL-Moe只设置linear_proj和linear_qkv训练LoRA也支持转换)。但大多数的模型支持LoRA转换,例如:Qwen3-Moe,Qwen3-Omni-Moe,GLM4.5-V等。 +- 若你希望导出merge后的权重,而不是LoRA增量权重,请设置`--merge_lora true`。设置`--merge_lora true`的兼容性更好,支持所有系列模型。 +- 注意:由于transformers和Megatron模型结构并不一定一致(例如transformers的Qwen3-VL-Moe的专家部分并不是Linear实现,而是Parameters),因此部分模型无法转换LoRA增量权重(若Qwen3-VL-Moe只设置linear_proj和linear_qkv训练LoRA也支持转换)。但大多数的模型支持LoRA转换,例如:Qwen3-Moe,Qwen3-Omni-Moe,GLM4.5-V等。 ```shell # 50GiB PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ @@ -273,3 +273,99 @@ megatron export \ --expert_model_parallel_size 2 \ --pipeline_model_parallel_size 2 ``` + + +## 使用代码 + +你需要创建以下文件(test.py),然后运行`CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 test.py`。以下为使用Mcore-Bridge进行权重加载、导出、保存的示例代码。 + +```python +import torch + +from swift.megatron import MegatronArguments, convert_hf_config, get_megatron_model_meta +from swift.llm import get_model_tokenizer +from megatron.training.initialize import initialize_megatron + +_, processor = get_model_tokenizer('Qwen/Qwen3-4B-Instruct-2507', load_model=False, download_model=True) +model_info = processor.model_info +megatron_model_meta = get_megatron_model_meta(model_info.model_type) +config_kwargs = convert_hf_config(model_info.config) +megatron_args = MegatronArguments( + 
tensor_model_parallel_size=2, + torch_dtype=torch.bfloat16, + **config_kwargs, +) +extra_args = megatron_args.parse_to_megatron() +extra_args['model_info'] = model_info +extra_args['megatron_model_meta'] = megatron_model_meta +initialize_megatron(args_defaults=extra_args) +mg_model = megatron_model_meta.model_provider() +bridge = megatron_model_meta.bridge_cls() +# 加载权重 +bridge.load_weights(mg_model, model_info.model_dir) +# 导出权重 +for name, parameters in bridge.export_weights([mg_model]): + pass +# 保存权重 +bridge.save_weights([mg_model], 'output/Qwen3-4B-Instruct-2507-new') +``` + +推理新产生的权重: +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --model output/Qwen3-4B-Instruct-2507-new \ + --model_type qwen3_nothinking \ + --stream true +``` + +LoRA权重的加载、导出和存储同理,运行`CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py` +```python +import torch + +from swift.megatron import ( + MegatronArguments, convert_hf_config, get_megatron_model_meta, prepare_mcore_model +) +from swift.llm import get_model_tokenizer +from megatron.training.initialize import initialize_megatron + +_, processor = get_model_tokenizer('Qwen/Qwen3-30B-A3B-Instruct-2507', load_model=False, download_model=True) +model_info = processor.model_info +megatron_model_meta = get_megatron_model_meta(model_info.model_type) +config_kwargs = convert_hf_config(model_info.config) +megatron_args = MegatronArguments( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + expert_model_parallel_size=2, + sequence_parallel=True, + moe_grouped_gemm=True, + torch_dtype=torch.bfloat16, + train_type='lora', + **config_kwargs, +) +extra_args = megatron_args.parse_to_megatron() +extra_args['model_info'] = model_info +extra_args['megatron_model_meta'] = megatron_model_meta +initialize_megatron(args_defaults=extra_args) +mg_model = megatron_model_meta.model_provider() +# 加载权重 +bridge = megatron_model_meta.bridge_cls() +bridge.load_weights(mg_model, model_info.model_dir) +# 准备LoRA并加载 +peft_model = prepare_mcore_model(mg_model) +print(f'peft_model: {peft_model}') +# bridge.load_weights(mg_model, 'adapter-path', is_peft_format=True) +# 导出权重 +for name, parameters in bridge.export_weights([mg_model], is_peft_format=True): + pass +bridge.save_weights([mg_model], 'output/Qwen3-30B-A3B-Instruct-2507-lora', is_peft_format=True) +``` + +推理新产生的权重: +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --adapters output/Qwen3-30B-A3B-Instruct-2507-lora \ + --stream true +``` diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index f93a873477..e8c191997e 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -196,7 +196,8 @@ - 🔥expert_tensor_parallel_size: 专家TP并行度。默认值为1。 - 在"ms-swift<3.9",其默认值为None,即等于`--tensor_model_parallel_size` 的数值,该默认值将在"ms-swift>=3.9"被修改。 - moe_token_dispatcher_type: 要使用的token分发器类型。可选选项包括 'allgather'、'alltoall'、'flex'和'alltoall_seq'。默认值为'alltoall'。 -- 🔥moe_grouped_gemm: 当每个rank包含多个专家时,通过在多个流中启动多个本地 GEMM 内核,利用 TransformerEngine中的GroupedLinear提高利用率和性能。默认为False。 +- 🔥moe_grouped_gemm: 当每个rank包含多个专家时,通过在多个流中启动多个本地 GEMM 内核,利用 TransformerEngine中的GroupedLinear提高利用率和性能。默认为True。 + - 在"ms-swift>=3.10",该参数默认值从False修改为True。 - 🔥moe_permute_fusion: 在令牌分发过程中融合令牌重排操作。默认为False。 - 
🔥moe_aux_loss_coeff: 默认为0,不使用aux_loss。**通常情况下,该值设置的越大,训练效果越差,但MoE负载越均衡**,请根据实验效果,选择合适的值。 - 注意:在"ms-swift<3.7.1",其默认为None,自动从config.json读取。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 125a7179e9..ef62f907e4 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -208,7 +208,8 @@ For guidance on selecting parallelization strategies, please refer to the [Train - 🔥expert_tensor_parallel_size: expert tensor-parallel size. Default is 1. - In "ms-swift<3.9", its default is `None`, which means it equals the value of `--tensor_model_parallel_size`. This default will be changed in "ms-swift>=3.9". - moe_token_dispatcher_type: The type of token dispatcher to use. Options include 'allgather', 'alltoall', 'flex', and 'alltoall_seq'. Default is 'alltoall'. -- 🔥moe_grouped_gemm: When each rank contains multiple experts, multiple local GEMM kernels can be launched in parallel streams to improve utilization and performance by using GroupedLinear from TransformerEngine. Default is False. +- 🔥moe_grouped_gemm: When each rank contains multiple experts, multiple local GEMM kernels can be launched in parallel streams to improve utilization and performance by using GroupedLinear from TransformerEngine. Default is True. + - In "ms-swift>=3.10", the default value of this parameter was changed from False to True. - 🔥moe_permute_fusion: Fuses token permutation operations during token dispatch. Default is False. - 🔥moe_aux_loss_coeff: Defaults to 0, meaning the auxiliary loss is not used. **Generally, a higher value leads to worse training performance but more balanced MoE expert utilization.** Please choose an appropriate value based on experimental results. - Note: In ms-swift versions earlier than 3.7.1, the default is None and the value is automatically loaded from config.json. diff --git a/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md b/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md index 215c2d2ccb..daeb205a7a 100644 --- a/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md +++ b/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md @@ -138,8 +138,8 @@ In addition to supporting full parameter import/export, Mcore-Bridge also suppor Below is an example of self-cognition training using LoRA for the text-only model Qwen3-Moe: -- If you want to export merged weights instead of LoRA incremental weights, please set `--merge_lora true`. -- Note: Since transformers and Megatron model structures may not always be consistent (for example, the expert part of Qwen3-VL-Moe in transformers is not implemented as Linear but as Parameters), some models cannot be converted (though Qwen3-VL-Moe supports conversion if only linear_proj and linear_qkv are set for LoRA training). However, most models support LoRA conversion, such as: Qwen3-Moe, Qwen3-Omni-Moe, GLM4.5-V, etc. +- If you want to export merged weights instead of LoRA delta weights, please set `--merge_lora true`. Setting `--merge_lora true` has better compatibility and supports all model series. +- Note: Since the model structures of transformers and Megatron are not necessarily identical (for example, the expert part of transformers' Qwen3-VL-Moe is not implemented as Linear, but as Parameters), some models cannot convert LoRA delta weights (however, if Qwen3-VL-Moe only sets linear_proj and linear_qkv for LoRA training, conversion is also supported). 
But most models support LoRA conversion, such as: Qwen3-Moe, Qwen3-Omni-Moe, GLM4.5-V, etc. ```shell # 50GiB @@ -285,3 +285,102 @@ megatron export \ --expert_model_parallel_size 2 \ --pipeline_model_parallel_size 2 ``` + + +## Using Code + +You need to create the following file (test.py), then run `CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 test.py`. Below is sample code for loading, exporting, and saving weights using Mcore-Bridge. + +```python +import torch + +from swift.megatron import MegatronArguments, convert_hf_config, get_megatron_model_meta +from swift.llm import get_model_tokenizer +from megatron.training.initialize import initialize_megatron + +_, processor = get_model_tokenizer('Qwen/Qwen3-4B-Instruct-2507', load_model=False, download_model=True) +model_info = processor.model_info +megatron_model_meta = get_megatron_model_meta(model_info.model_type) +config_kwargs = convert_hf_config(model_info.config) +megatron_args = MegatronArguments( + tensor_model_parallel_size=2, + torch_dtype=torch.bfloat16, + **config_kwargs, +) +extra_args = megatron_args.parse_to_megatron() +extra_args['model_info'] = model_info +extra_args['megatron_model_meta'] = megatron_model_meta +initialize_megatron(args_defaults=extra_args) +mg_model = megatron_model_meta.model_provider() +bridge = megatron_model_meta.bridge_cls() +# Load weights +bridge.load_weights(mg_model, model_info.model_dir) +# Export weights +for name, parameters in bridge.export_weights([mg_model]): + pass +# Save weights +bridge.save_weights([mg_model], 'output/Qwen3-4B-Instruct-2507-new') +``` + +Inference with the newly generated weights: + +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --model output/Qwen3-4B-Instruct-2507-new \ + --model_type qwen3_nothinking \ + --stream true +``` + +Loading, exporting, and saving LoRA weights follows the same pattern. 
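+
+As a side note on the full-parameter example above: `export_weights` yields `(name, tensor)` pairs in HF format, so the loop body does not have to be `pass`. The following minimal sketch (illustrative only — it reuses `bridge` and `mg_model` from the example above, and the counting/printing is not part of the Mcore-Bridge API) streams the iterator instead of discarding it; when run under torchrun with multiple ranks, each process will print its own copy:
+
+```python
+# Sketch: inspect the exported HF-format weights instead of discarding them.
+total_numel = 0
+for name, parameters in bridge.export_weights([mg_model]):
+    total_numel += parameters.numel()
+    print(f'{name}: shape={tuple(parameters.shape)}, dtype={parameters.dtype}')
+print(f'exported ~{total_numel / 1e9:.2f}B parameter elements')
+```
+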
Run `CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py` + +```python +import torch + +from swift.megatron import ( + MegatronArguments, convert_hf_config, get_megatron_model_meta, prepare_mcore_model +) +from swift.llm import get_model_tokenizer +from megatron.training.initialize import initialize_megatron + +_, processor = get_model_tokenizer('Qwen/Qwen3-30B-A3B-Instruct-2507', load_model=False, download_model=True) +model_info = processor.model_info +megatron_model_meta = get_megatron_model_meta(model_info.model_type) +config_kwargs = convert_hf_config(model_info.config) +megatron_args = MegatronArguments( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + expert_model_parallel_size=2, + sequence_parallel=True, + moe_grouped_gemm=True, + torch_dtype=torch.bfloat16, + train_type='lora', + **config_kwargs, +) +extra_args = megatron_args.parse_to_megatron() +extra_args['model_info'] = model_info +extra_args['megatron_model_meta'] = megatron_model_meta +initialize_megatron(args_defaults=extra_args) +mg_model = megatron_model_meta.model_provider() +# Load weights +bridge = megatron_model_meta.bridge_cls() +bridge.load_weights(mg_model, model_info.model_dir) +# Prepare LoRA and load +peft_model = prepare_mcore_model(mg_model) +print(f'peft_model: {peft_model}') +# bridge.load_weights(mg_model, 'adapter-path', is_peft_format=True) +# Export weights +for name, parameters in bridge.export_weights([mg_model], is_peft_format=True): + pass +bridge.save_weights([mg_model], 'output/Qwen3-30B-A3B-Instruct-2507-lora', is_peft_format=True) +``` + +Inference with the newly generated weights: + +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --adapters output/Qwen3-30B-A3B-Instruct-2507-lora \ + --stream true +``` diff --git a/swift/llm/template/template/glm.py b/swift/llm/template/template/glm.py index 2972b67c43..7711d10ddb 100644 --- a/swift/llm/template/template/glm.py +++ b/swift/llm/template/template/glm.py @@ -317,7 +317,7 @@ def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: return packed def _get_position_ids(self, inputs: Dict[str, Any]): - base_model = self.get_base_model(self.model) + base_model = self.get_base_model(self._get_model()) attention_mask = inputs.get('attention_mask_2d') if attention_mask is None: attention_mask = inputs.get('attention_mask') diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index 4c32b9ceda..0fde540c9e 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -426,7 +426,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]): kwargs = {} if self.version == 'v2_5': kwargs = {'second_per_grid_ts': inputs.get('second_per_grid_ts')} - base_model = self.get_base_model(self.model) + base_model = self.get_base_model(self._get_model()) if hasattr(base_model, 'get_rope_index'): get_rope_index = base_model.get_rope_index else: @@ -775,7 +775,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]): attention_mask = inputs.get('attention_mask') if attention_mask is None: attention_mask = torch.ones_like(input_ids) - position_ids, _ = self.model.thinker.get_rope_index( + position_ids, _ = self._get_model().thinker.get_rope_index( input_ids, inputs.get('image_grid_thw'), inputs.get('video_grid_thw'), diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 9815a0ea87..ca2988c723 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -15,8 +15,8 @@ from .train import 
megatron_sft_main, megatron_pt_main, megatron_rlhf_main from .export import megatron_export_main from .convert import convert_hf2mcore, convert_mcore2hf - from .utils import prepare_mcore_model, adapter_state_dict_context - from .argument import MegatronTrainArguments, MegatronRLHFArguments, MegatronExportArguments + from .utils import prepare_mcore_model, adapter_state_dict_context, convert_hf_config + from .argument import MegatronTrainArguments, MegatronRLHFArguments, MegatronExportArguments, MegatronArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer from .tuners import LoraParallelLinear @@ -25,8 +25,8 @@ 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], 'export': ['megatron_export_main'], 'convert': ['convert_hf2mcore', 'convert_mcore2hf'], - 'utils': ['prepare_mcore_model', 'adapter_state_dict_context'], - 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments', 'MegatronExportArguments'], + 'utils': ['prepare_mcore_model', 'adapter_state_dict_context', 'convert_hf_config'], + 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments', 'MegatronExportArguments', 'MegatronArguments'], 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], 'tuners': ['LoraParallelLinear'], diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index c4d29aa485..8a4aff42d2 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -452,7 +452,7 @@ class MegatronArguments(ExtraMegatronArguments): expert_tensor_parallel_size: int = 1 moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'flex', 'alltoall_seq'] = 'alltoall' moe_enable_deepep: bool = False - moe_grouped_gemm: bool = False + moe_grouped_gemm: bool = True moe_permute_fusion: bool = False moe_aux_loss_coeff: float = 0. 
moe_z_loss_coeff: Optional[float] = None diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py index 9054a4b44a..d5c09b3975 100644 --- a/swift/megatron/export/export.py +++ b/swift/megatron/export/export.py @@ -70,7 +70,8 @@ def convert_mcore2hf(self) -> None: kwargs = {'adapters': [args.save]} else: kwargs = {'model': args.save} - hf_model = prepare_model_template(args, device_map='cpu', **kwargs)[0] if is_last_rank() else None + hf_model, template = prepare_model_template( + args, device_map='cpu', **kwargs) if is_last_rank() else (None, template) test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) dist.barrier() @@ -111,7 +112,8 @@ def convert_hf2mcore(self) -> None: logger.info('Successfully transferred HF model weights to MG model.') if args.test_convert_precision: with disable_safe_ddp_context_use_barrier(): - hf_model = prepare_model_template(args, device_map='cpu')[0] if is_last_rank() else None + hf_model, template = prepare_model_template( + args, device_map='cpu') if is_last_rank() else (None, template) test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) dist.barrier() args.save_args(args.save) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 52578a9ff2..4970909b16 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -465,10 +465,6 @@ def _set_mlp_state(self, to_mcore: bool, ep_rank: Optional[int] = None, hf_mlp=None): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} if hf_mlp is None: hf_mlp = self.hf_layers[layer_idx].mlp is_expert = ep_rank is not None @@ -478,6 +474,10 @@ def _set_mlp_state(self, hf_grouped = not hasattr(hf_mlp.experts, '__len__') hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0] num_local_experts = self.args.num_experts // self.ep_size + if to_mcore or hf_grouped: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} # linear_fc1 if to_mcore: if isinstance(mg_mlp.linear_fc1, LoraParallelLinear): @@ -544,6 +544,8 @@ def _set_mlp_state(self, if is_expert: if hf_grouped: gate_up_proj_weight = hf_state_dict['gate_up_proj'].load().transpose(1, 2) + gate_up_proj_weight = gate_up_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) + * num_local_experts] gate_up_proj_weight = gate_up_proj_weight.reshape(num_local_experts * 2, -1, gate_up_proj_weight.shape[-1]) else: @@ -587,6 +589,7 @@ def _set_mlp_state(self, if is_expert and self.ep_size > 1: dist.all_reduce(is_lora, group=self.ep_group) if is_lora: + assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.' 
if mg_mlp is None: lora_A = None lora_B = None @@ -641,8 +644,11 @@ def _set_mlp_state(self, fc1_weight = None else: if is_expert: - fc1_weight = torch.concat( - [getattr(mg_mlp.linear_fc1, f'weight{i}') for i in range(num_local_experts)], dim=0) + linear_fc1 = mg_mlp.linear_fc1 + if isinstance(linear_fc1, LoraParallelLinear): + linear_fc1 = linear_fc1.base_layer + fc1_weight = torch.concat([getattr(linear_fc1, f'weight{i}') for i in range(num_local_experts)], + dim=0) else: fc1_weight = mg_mlp.linear_fc1.weight fc1_weight = fc1_weight.view(num_local_experts * 2, -1, fc1_weight.shape[1]) @@ -653,9 +659,16 @@ def _set_mlp_state(self, if is_expert: gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, -1, gate_up_proj_weight.shape[-1]) - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone() + if hf_grouped: + gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + if 'gate_up_proj' in hf_state_dict: + gate_up_proj_weight = torch.concat( + [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) + hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone() + else: + for i in range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone() del gate_up_proj_weight else: hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view( @@ -702,6 +715,9 @@ def _set_mlp_state(self, fc2_weight = fc2_weight.new_empty(num_local_experts * fc2_weight.shape[0], fc2_weight.shape[1]) if hf_grouped: down_proj_weight = hf_state_dict['down_proj'].load().transpose(1, 2) + down_proj_weight = down_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) + * num_local_experts].reshape( + -1, down_proj_weight.shape[-1]) else: down_proj_weight = torch.concat([ hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load() @@ -721,6 +737,7 @@ def _set_mlp_state(self, if is_expert and self.ep_size > 1: dist.all_reduce(is_lora, group=self.ep_group) if is_lora: + assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.' 
if mg_mlp is None: lora_A = None lora_B = None @@ -747,15 +764,24 @@ def _set_mlp_state(self, if mg_mlp is None: fc2_weight = None else: - fc2_weight = torch.concat( - [getattr(mg_mlp.linear_fc2, f'weight{i}') for i in range(num_local_experts)], dim=0) + linear_fc2 = mg_mlp.linear_fc2 + if isinstance(linear_fc2, LoraParallelLinear): + linear_fc2 = linear_fc2.base_layer + fc2_weight = torch.concat([getattr(linear_fc2, f'weight{i}') for i in range(num_local_experts)], + dim=0) down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert) del fc2_weight if down_proj_weight is not None: down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1]) - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone() + if hf_grouped: + down_proj_weight = down_proj_weight.transpose(1, 2) + if 'down_proj' in hf_state_dict: + down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0) + hf_state_dict['down_proj'] = down_proj_weight.clone() + else: + for i in range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone() else: self._set_state_dict( mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert) @@ -884,7 +910,7 @@ def _convert_hf_state_dict(self, hf_state_dict, to_mcore): res[k] = v return res - def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool): + def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqdm_desc: str = 'Converting: '): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) @@ -900,7 +926,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool): yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) hf_state_dict = {} for layer_idx in tqdm( - range(self.args.num_layers), dynamic_ncols=True, desc='Converting: ', disable=self.disable_tqmd): + range(self.args.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd): lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model start_idx = lm_model.decoder.layers[0].layer_number - 1 mg_layer_available = (start_idx <= layer_idx < lm_model.decoder.layers[-1].layer_number) @@ -936,9 +962,14 @@ def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False with SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader: state_dict = loader.get_state_dict() hf_prefix = 'base_model.model.' if is_peft_format else '' - list(self._convert([mg_model], state_dict, hf_prefix, True)) - - def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False, is_peft_format: bool = False): + list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: ')) + + def export_weights(self, + mg_models, + target_device=None, + only_last_rank: bool = False, + is_peft_format: bool = False, + tqdm_desc: str = 'Exporting: '): # TODO: modules_to_save self._target_device = target_device self._only_last_rank = only_last_rank @@ -946,14 +977,15 @@ def export_weights(self, mg_models, target_device=None, only_last_rank: bool = F self._peft_target_modules = set() self._peft_modules_to_save = set() hf_prefix = 'base_model.model.' 
if is_peft_format else '' - yield from self._convert(mg_models, {}, hf_prefix, False) + yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc) def save_weights(self, mg_models, output_dir: str, is_peft_format: bool = False) -> None: """Save the mg_model checkpoint in HF format""" saver = StreamingSafetensorSaver( save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format) for k, v in self.export_weights( - mg_models, target_device='cpu', only_last_rank=True, is_peft_format=is_peft_format): + mg_models, target_device='cpu', only_last_rank=True, is_peft_format=is_peft_format, + tqdm_desc='Saving: '): saver.add_tensor(k, v) saver.finalize() if is_last_rank(): diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index de0cea391c..fe45ce3a5d 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -276,7 +276,7 @@ def new_model_provider_func(*_args, **kwargs): with adapter_state_dict_context(): args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( model, optimizer, opt_param_scheduler, load_arg='adapter_load', strict=False) - if args.model_meta.is_multimodal: + if args.megatron_model_meta.is_multimodal: for m in self.unwrapped_models: self._prepare_vit_gradient_checkpointing(m) return model, optimizer, opt_param_scheduler diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index de3788519c..ee97dadd65 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -2,6 +2,7 @@ import base64 import gc import inspect +import os import uuid from collections import defaultdict from contextlib import contextmanager, nullcontext @@ -23,7 +24,9 @@ from swift.llm.infer.protocol import RolloutOutput from swift.plugin import MultiTurnScheduler, multi_turns, orms from swift.trainers.rlhf_trainer.grpo_trainer import DataType -from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, _create_parameter_buckets, + _process_bucket_with_flattened_tensor, + replace_assistant_response_with_ids) from swift.utils import get_current_device, get_logger, is_master, is_vllm_available, remove_response from ..argument import MegatronArguments, MegatronRLHFArguments from ..utils import forward_step_helper @@ -183,14 +186,12 @@ def _move_model_to_vllm(self): restore_func = patch_model_for_lora_export(model) restore_funcs.append(restore_func) + # Export weights from megatron models using bridge per_tensor_params = dict(self.bridge.export_weights(self.unwrapped_models)) - if self.vllm_mode == 'server': - if self.is_main_process: - for name, param in per_tensor_params.items(): - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == 'colocate': - self.engine.inner_model.load_weights(per_tensor_params) + # Load weights to vLLM engine + self._load_weights_to_vllm(per_tensor_params) + finally: for restore_func in restore_funcs: restore_func() @@ -200,11 +201,39 @@ def _move_model_to_vllm(self): logger.info('Unmerging LoRA adapters to restore training state...') self._unmerge_lora_adapters() + # Reset prefix cache if self.vllm_mode == 'server' and self.is_main_process: self.vllm_client.reset_prefix_cache() elif self.vllm_mode == 'colocate': self.engine.engine.reset_prefix_cache() + def _load_weights_to_vllm(self, state_dict: Dict[str, torch.Tensor]): + """ + Load state_dict to vLLM engine using flattened 
tensor optimization. + + For server mode: Uses FlattenedTensorBucket to batch parameters and reduce communication overhead. + For colocate mode: Directly loads weights to the inner model. + """ + if self.vllm_mode == 'server' and self.is_main_process: + # Use flattened tensor optimization for efficient weight transfer + bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) + named_params = list(state_dict.items()) + + # Create parameter buckets for efficient processing + parameter_buckets = _create_parameter_buckets(named_params, bucket_size_mb=bucket_size_mb) + + # Process each bucket with flattened tensor + for bucket in parameter_buckets: + _process_bucket_with_flattened_tensor(self, bucket) + + del named_params, parameter_buckets + elif self.vllm_mode == 'colocate': + # Colocate mode: direct weight loading + llm_model = self.engine.inner_model + llm_model.load_weights(state_dict.items()) + + del state_dict + def _prepare_rewards(self): # TODO: reward model args = self.args diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 77cce7ec7f..fe77b7f97b 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -100,7 +100,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): if cp_size > 1: args = get_args() keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] - if args.model_meta.is_multimodal: + if args.megatron_model_meta.is_multimodal: keys.append('decoder_input') else: keys.append('input_ids') diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 9056c6f780..6012c72f71 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -86,7 +86,7 @@ def get_target_modules(args, model): return args.target_modules target_modules = args.target_modules.copy() if 'all-linear' in target_modules: - if args.model_meta.is_multimodal: + if args.megatron_model_meta.is_multimodal: return get_multimodal_target_regex( args, model, From 1d930d8599a7eecc71b4a54017f846404fd0e0ac Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 5 Nov 2025 19:56:24 +0800 Subject: [PATCH 46/83] dynamic sample --- swift/megatron/trainers/grpo_trainer.py | 221 +++++++++++++++++++----- 1 file changed, 182 insertions(+), 39 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index ee97dadd65..5b12caf2cf 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -8,7 +8,7 @@ from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from functools import partial -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import json import torch @@ -31,9 +31,9 @@ from ..argument import MegatronArguments, MegatronRLHFArguments from ..utils import forward_step_helper from .rlhf_mixin import MegatronRLHFTrainer -from .utils import (gather, gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, - offload_megatron_model_to_cpu, offload_megatron_optimizer, patch_model_for_lora_export, - profiling_context) +from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, + load_megatron_optimizer, log_gpu_memory, offload_megatron_model_to_cpu, offload_megatron_optimizer, + patch_model_for_lora_export, profiling_context) logger = get_logger() @@ -54,6 +54,13 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): self._metrics = {'train': 
defaultdict(list), 'eval': defaultdict(list)} + def train(self, train_dataset, val_dataset, data_collator): + # Store dataset provider for lazy resample iterator initialization + if self.dynamic_sample: + self._train_valid_test_dataset_provider = get_swift_datasets_provider(train_dataset, val_dataset) + self._train_valid_test_dataset_provider.is_distributed = True + super().train(train_dataset, val_dataset, data_collator) + def _prepare_template_data_collator(self): template = self.template args = self.args @@ -87,6 +94,12 @@ def _init_grpo_params(self): self.top_entropy_quantile = args.top_entropy_quantile self.importance_sampling_level = args.importance_sampling_level self.enable_offload = False + + # DAPO, https://arxiv.org/abs/2503.14476 + self.dynamic_sample = args.dynamic_sample + self.max_resample_times = args.max_resample_times + self.overlong_filter = args.overlong_filter + # batch size (completion-level) self.generation_batch_size = args.generation_batch_size self.steps_per_generation = args.steps_per_generation @@ -372,26 +385,57 @@ def _get_rollout_group(self): return self._rollout_group - def _replace_data_iterator(self, data_iterator, model): + def _init_resample_data_iterator(self): + """ + Initialize an independent data iterator for dynamic resampling (lazy initialization). + + This method is called lazily during the first dynamic resampling, ensuring that + pretrain() has already called initialize_megatron() to properly set up all args. + Uses a different seed (args.seed + 1) to avoid overlapping with training samples. + + Note: pretrain() will automatically reset the random seed back to args.seed + after this method completes, so we don't need manual state restoration. + + Args: + train_valid_test_dataset_provider: Dataset provider function + + Returns: + train_data_iterator: Independent data iterator with different random seed + """ + from megatron.training.training import build_train_valid_test_data_iterators + from megatron.training.initialize import _set_random_seed + args = get_args() + + train_valid_test_dataset_provider = self._train_valid_test_dataset_provider + # Use different seed for resample iterator (offset by 1 to avoid overlap) + resample_seed = getattr(args, 'seed', 42) + 1 + try: + # Set new seed for resample iterator creation + _set_random_seed( + resample_seed, + args.data_parallel_random_init, + args.te_rng_tracker, + args.inference_rng_tracker, + use_cudagraphable_rng=args.enable_cuda_graph, + ) + + # Build data iterators with new seed + # TODO: VPP (Virtual Pipeline Parallelism) + resample_data_iterator, _, _ = (build_train_valid_test_data_iterators(train_valid_test_dataset_provider)) + finally: + # Restore original random states to avoid affecting training + _set_random_seed( + args.seed, + args.data_parallel_random_init, + args.te_rng_tracker, + args.inference_rng_tracker, + use_cudagraphable_rng=args.enable_cuda_graph, + ) + return resample_data_iterator + def _replace_data_iterator(self, data_iterator, model): if self._step % self.steps_per_generation == 0: - # each rollout DP group will generate generation_batch_size / world_size completions - dp_size = mpu.get_data_parallel_world_size() - completions_to_rollout = self.generation_batch_size // dp_size - # completions will be repeated num_generations times after - # so we need to divide num_iters_per_step by num_generations to get prompt batch size - prompts_to_rollout = completions_to_rollout // self.num_generations - # every iter will generate micro_batch_size prompts - num_iters_per_step = 
prompts_to_rollout // self.micro_batch_size - assert num_iters_per_step > 0, ( - f'num_iters_per_step={num_iters_per_step} <= 0. ' - f'This means no prompts will be generated' - f'generation_batch_size={self.generation_batch_size}, ' - f'data_parallel_world_size={mpu.get_data_parallel_world_size()}, ' - f'num_generations={self.num_generations}, ' - f'micro_batch_size={self.micro_batch_size}. ' - 'Please adjust these parameters so that ' - 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.') + num_iters_per_step = self.get_num_iters_per_step() rollout_batch = [] for _ in range(num_iters_per_step): rollout_batch.extend(next(data_iterator)) @@ -410,28 +454,17 @@ def _generate_and_score_completions(self, batch): # Get or create the rollout group (TP×PP×CP) rollout_group = self._get_rollout_group() - # batch : same across DP groups - def get_local_rollout_batch(batch): - # repeat num_generations times - global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] - # get local rollout data - rollout_rank = torch.distributed.get_rank(group=rollout_group) - rollout_group_size = torch.distributed.get_world_size(group=rollout_group) - - per_device_batch_size = self.per_device_generation_batch_size - assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) - data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) - rollout_batch = global_rollout_batch[data_slice] - return rollout_batch - - # Step1: Rollout / Reward / Advantage - - rollout_batch = get_local_rollout_batch(batch) + rollout_batch = self.get_local_rollout_batch(batch) rollout_batch = self._generate_completions(rollout_batch) rewards_per_func = self._score_completions(rollout_batch) + # Dynamic sampling for std=0 groups (DAPO) + if self.dynamic_sample: + rollout_batch, rewards_per_func = self._dynamic_sampling(rollout_batch, rewards_per_func, rollout_group, + batch) + advantages = self._compute_advantages(rollout_batch, rewards_per_func) def _get_encoded_batch(rollout_batch, advantages): @@ -797,6 +830,79 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor return advantages + def _dynamic_sampling(self, rollout_batch: DataType, + rewards_per_func: torch.Tensor) -> Tuple[DataType, torch.Tensor]: + """ + Perform dynamic sampling to replace samples with zero-reward-variance groups. + + This method implements DAPO (https://arxiv.org/abs/2503.14476) by replacing + samples from groups with zero reward variance (std=0) through resampling. 
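+
+        A tiny example: with num_generations=4, a group whose rewards are
+        [1, 1, 1, 1] has std 0 (zero advantage, hence no learning signal) and all
+        four samples are dropped, while a group with rewards [1, 0, 1, 0] has
+        std > 0 and is kept; dropped slots are refilled by generating and scoring
+        new prompts from the resample iterator, up to max_resample_times attempts.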
+ + Args: + rollout_batch: local rollout data samples + rewards_per_func: reward per function for local data samples + rollout_group: rollout communication group + + Returns: + tuple: (rollout_batch, rewards_per_func) with zero-variance groups replaced by resampled data + """ + resample_count = 0 + valid_samples = [] + valid_rewards_per_func = [] + origin_data = (rollout_batch, rewards_per_func) + + while resample_count < self.max_resample_times: + # Gather all samples and rewards across rollout group first + global_rollout_batch = gather_object(rollout_batch) + global_rewards_per_func = gather(rewards_per_func) + + # Compute reward std for the entire global batch + # We need to compute std on the gathered data to get a global mask + global_rewards = (global_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) + grouped_rewards = global_rewards.view(-1, self.num_generations) + group_rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations) + global_valid_mask = (group_rewards_std > 0) + + # Filter valid samples based on std > 0 + valid_samples.extend([sample for sample, mask in zip(global_rollout_batch, global_valid_mask) if mask]) + valid_rewards_per_func.append(global_rewards_per_func[global_valid_mask]) + + if len(valid_samples) >= self.generation_batch_size: + break + + # Lazy initialization of resample_data_iterator + # Only initialize when needed, after pretrain() has set up args + if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None: + assert hasattr(self, '_train_valid_test_dataset_provider'), \ + 'Dataset provider not set. Make sure dynamic_sample is enabled.' + self.resample_data_iterator = self._init_resample_data_iterator() + num_iters_per_step = self.get_num_iters_per_step() + next_rollout_prompt_batch = [] + for _ in range(num_iters_per_step): + next_rollout_prompt_batch.extend(next(self.resample_data_iterator)) + + # Repeat num_generations times and get local slice + rollout_batch = self.get_local_rollout_batch(next_rollout_prompt_batch) + + # Generate and score new completions + rollout_batch = self._generate_completions(rollout_batch) + rewards_per_func = self._score_completions(rollout_batch) + resample_count += 1 + + if len(valid_samples) >= self.generation_batch_size: + # Get local slice of valid samples + rank = self.process_index + per_device_batch_size = self.per_device_generation_batch_size + assert self.world_size * per_device_batch_size == len(valid_samples) + data_slice = slice(rank * per_device_batch_size, (rank + 1) * per_device_batch_size) + rollout_batch = valid_samples[:self.generation_batch_size][data_slice] + rewards_per_func = torch.cat(valid_rewards_per_func)[:self.generation_batch_size][data_slice] + else: + logger.warning(f'There are still std=0 groups present after {self.max_resample_times} retries.') + rollout_batch, rewards_per_func = origin_data + + return rollout_batch, rewards_per_func + def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: # TODO: entropy inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']} @@ -1114,3 +1220,40 @@ def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: input_item['request_id'] = f'chatcmpl-{str(uuid.uuid4().hex)}' return inputs + + def get_num_iters_per_step(self): + if hasattr(self, '_num_iters_per_step'): + return self._num_iters_per_step + # each rollout DP group will generate generation_batch_size / world_size completions + dp_size = 
mpu.get_data_parallel_world_size() + completions_to_rollout = self.generation_batch_size // dp_size + # completions will be repeated num_generations times after + # so we need to divide num_iters_per_step by num_generations to get prompt batch size + prompts_to_rollout = completions_to_rollout // self.num_generations + # every iter will generate micro_batch_size prompts + num_iters_per_step = prompts_to_rollout // self.micro_batch_size + assert num_iters_per_step > 0, ( + f'num_iters_per_step={num_iters_per_step} <= 0. ' + f'This means no prompts will be generated' + f'generation_batch_size={self.generation_batch_size}, ' + f'data_parallel_world_size={mpu.get_data_parallel_world_size()}, ' + f'num_generations={self.num_generations}, ' + f'micro_batch_size={self.micro_batch_size}. ' + 'Please adjust these parameters so that ' + 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.') + self._num_iters_per_step = num_iters_per_step + return num_iters_per_step + + def get_local_rollout_batch(self, batch): + # repeat num_generations times + rollout_group = self._get_rollout_group() + global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] + # get local rollout data + rollout_rank = torch.distributed.get_rank(group=rollout_group) + rollout_group_size = torch.distributed.get_world_size(group=rollout_group) + + per_device_batch_size = self.per_device_generation_batch_size + assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) + data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) + rollout_batch = global_rollout_batch[data_slice] + return rollout_batch From 5f9e14a77299cda591b37e580e242213a04f5a31 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 6 Nov 2025 00:43:48 +0800 Subject: [PATCH 47/83] fix dynamic sampling --- swift/megatron/trainers/base.py | 1 + swift/megatron/trainers/grpo_trainer.py | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index fe45ce3a5d..027a458b20 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -766,6 +766,7 @@ def _patch_megatron(self): # support max_epochs self._origin_train_step = training.train_step training.train_step = self.train_step + self._origin_cyclic_iter = training.cyclic_iter training.cyclic_iter = self.new_cyclic_iter # patch training_log self._origin_training_log = training.training_log diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 5b12caf2cf..1d03aa02ea 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -404,6 +404,8 @@ def _init_resample_data_iterator(self): """ from megatron.training.training import build_train_valid_test_data_iterators from megatron.training.initialize import _set_random_seed + from megatron.training import training + training.cyclic_iter = self._origin_cyclic_iter args = get_args() train_valid_test_dataset_provider = self._train_valid_test_dataset_provider @@ -462,8 +464,7 @@ def _generate_and_score_completions(self, batch): # Dynamic sampling for std=0 groups (DAPO) if self.dynamic_sample: - rollout_batch, rewards_per_func = self._dynamic_sampling(rollout_batch, rewards_per_func, rollout_group, - batch) + rollout_batch, rewards_per_func = self._dynamic_sampling(rollout_batch, rewards_per_func) advantages = self._compute_advantages(rollout_batch, 
rewards_per_func) @@ -873,8 +874,6 @@ def _dynamic_sampling(self, rollout_batch: DataType, # Lazy initialization of resample_data_iterator # Only initialize when needed, after pretrain() has set up args if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None: - assert hasattr(self, '_train_valid_test_dataset_provider'), \ - 'Dataset provider not set. Make sure dynamic_sample is enabled.' self.resample_data_iterator = self._init_resample_data_iterator() num_iters_per_step = self.get_num_iters_per_step() next_rollout_prompt_batch = [] @@ -893,7 +892,6 @@ def _dynamic_sampling(self, rollout_batch: DataType, # Get local slice of valid samples rank = self.process_index per_device_batch_size = self.per_device_generation_batch_size - assert self.world_size * per_device_batch_size == len(valid_samples) data_slice = slice(rank * per_device_batch_size, (rank + 1) * per_device_batch_size) rollout_batch = valid_samples[:self.generation_batch_size][data_slice] rewards_per_func = torch.cat(valid_rewards_per_func)[:self.generation_batch_size][data_slice] From d1460c2b0ab3dca55cc71ec0f9e5c980f6f0b69f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 7 Nov 2025 11:14:50 +0800 Subject: [PATCH 48/83] fix cp: compute part of seq loss --- swift/megatron/argument/megatron_args.py | 5 ++- swift/megatron/trainers/grpo_trainer.py | 47 ++++++++---------------- swift/megatron/trainers/rlhf_mixin.py | 6 +-- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 254d60f6d6..1bc45a6ee2 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -169,11 +169,12 @@ def _check_batch_params(): raise ValueError( "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") world_size = torch.distributed.get_world_size() + dp_size = world_size // (self.pipeline_model_parallel_size * self.tensor_model_parallel_size) num_rollout_prompt = self.generation_batch_size // self.num_generations - assert num_rollout_prompt % world_size == 0, ( + assert num_rollout_prompt % dp_size == 0, ( f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size ' f'({self.generation_batch_size}) // num_generations ({self.num_generations}) ' - f'must be divisible by the world size ({world_size})' + f'must be divisible by the dp size ({dp_size})' f'please adjust generation_batch_size/steps_per_generation/num_generations to make it divisible') self.per_device_generation_batch_size = self.generation_batch_size // world_size diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 1d03aa02ea..30629402b4 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -17,6 +17,7 @@ from dacite import from_dict from megatron.core import mpu from megatron.training import get_args, training +from torch.distributed.nn import all_reduce from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps @@ -33,7 +34,12 @@ from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, offload_megatron_model_to_cpu, offload_megatron_optimizer, - patch_model_for_lora_export, profiling_context) + profiling_context) + +try: + from trl.trainer.utils import entropy_from_logits +except ImportError: + from 
swift.trainers.rlhf_trainer.utils import entropy_from_logits logger = get_logger() @@ -100,6 +106,10 @@ def _init_grpo_params(self): self.max_resample_times = args.max_resample_times self.overlong_filter = args.overlong_filter + # Entropy mask settings, TODO + self.log_entropy = args.log_entropy + self.compute_entropy = self.log_entropy or self.top_entropy_quantile < 1.0 + # batch size (completion-level) self.generation_batch_size = args.generation_batch_size self.steps_per_generation = args.steps_per_generation @@ -190,14 +200,10 @@ def prepare_vllm(self): def _move_model_to_vllm(self): # Handle LoRA: merge adapters before exporting weights is_lora_training = self.args.train_type == 'lora' - restore_funcs = [] try: if is_lora_training: - self._merge_lora_adapters() - for model in self.unwrapped_models: - restore_func = patch_model_for_lora_export(model) - restore_funcs.append(restore_func) + self.merge_lora_adapters() # Export weights from megatron models using bridge per_tensor_params = dict(self.bridge.export_weights(self.unwrapped_models)) @@ -206,13 +212,9 @@ def _move_model_to_vllm(self): self._load_weights_to_vllm(per_tensor_params) finally: - for restore_func in restore_funcs: - restore_func() - # Unmerge adapters to restore training state if is_lora_training: - logger.info('Unmerging LoRA adapters to restore training state...') - self._unmerge_lora_adapters() + self.unmerge_lora_adapters() # Reset prefix cache if self.vllm_mode == 'server' and self.is_main_process: @@ -311,24 +313,6 @@ def _prepare_scheduler(self): assert isinstance(args.multi_turn_scheduler, MultiTurnScheduler) self.multi_turn_scheduler: MultiTurnScheduler = args.multi_turn_scheduler - def _merge_lora_adapters(self): - """Merge LoRA adapters into base model weights for vLLM inference.""" - from ..tuners import LoraParallelLinear - for model in self.unwrapped_models: - for module in model.modules(): - if isinstance(module, LoraParallelLinear): - # Merge all active adapters - module.merge() - - def _unmerge_lora_adapters(self): - """Unmerge LoRA adapters to restore training state.""" - from ..tuners import LoraParallelLinear - for model in self.unwrapped_models: - for module in model.modules(): - if isinstance(module, LoraParallelLinear): - # Unmerge to restore separate LoRA weights for training - module.unmerge() - def _get_rollout_group(self): """ Get or create the rollout process group (TP×PP×CP). @@ -788,6 +772,7 @@ def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tens return advantages / (rewards_std + 1e-4) return advantages + mode = 'train' if self.unwrapped_models[0].training else 'eval' assert len(batch) == rewards_per_func.shape[0] total_rewards_per_func = gather(rewards_per_func) rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) @@ -807,7 +792,6 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor """Log reward statistics for monitoring. 
Only log once per unique request_id.""" # rewards: [prompt_batch_size, self.num_generations] # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] - mode = 'train' if self.unwrapped_models[0].training else 'eval' group_rewards = rewards.view(-1, self.num_generations) rewards_mean = group_rewards.mean(-1).mean().item() rewards_std = group_rewards.std(-1).mean().item() @@ -1032,7 +1016,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): else: raise ValueError( f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " - "and 'sequence'.") + ",'sequence' and 'sequence_token'.") coef_1 = torch.exp(log_importance_weights) coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) @@ -1062,6 +1046,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): else: raise ValueError(f'Unknown loss type: {self.loss_type}') + loss = all_reduce(loss, group=mpu.get_context_parallel_group()) # loss = loss.mean() avg_metric = { 'loss': loss.clone().detach(), diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index a9cf8dae5b..47df78fcdc 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -1,11 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from contextlib import contextmanager +import torch +import torch.distributed as dist from megatron.core import mpu from megatron.training import get_args, get_model from megatron.training.checkpointing import load_checkpoint from megatron.training.utils import unwrap_model -from torch.distributed.nn import all_reduce +from torch.distributed.nn import all_gather, all_reduce from transformers.utils import ContextManagers from swift.utils import get_logger @@ -60,8 +62,6 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask if per_token: - if args.context_parallel_size > 1: - per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) return per_token_logps if num_samples is None: From db7000d7798c8d5c86154eb3a1f9c3f5fe37bc00 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 7 Nov 2025 16:24:31 +0800 Subject: [PATCH 49/83] optimize weight sync & fix vllm_tp --- swift/megatron/argument/megatron_args.py | 5 +- swift/megatron/trainers/grpo_trainer.py | 163 +++++++++++++++-------- swift/megatron/trainers/utils.py | 11 +- 3 files changed, 120 insertions(+), 59 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 1bc45a6ee2..905384dd5c 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -169,7 +169,10 @@ def _check_batch_params(): raise ValueError( "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") world_size = torch.distributed.get_world_size() - dp_size = world_size // (self.pipeline_model_parallel_size * self.tensor_model_parallel_size) + # total_model_size = TP × PP × CP, + # data_parallel_size = world_size // total_model_size + dp_size = world_size // ( + self.pipeline_model_parallel_size * self.tensor_model_parallel_size * self.context_parallel_size) num_rollout_prompt = self.generation_batch_size // self.num_generations assert num_rollout_prompt % dp_size == 0, ( f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size ' diff --git 
a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 30629402b4..27de3c333c 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -26,8 +26,8 @@ from swift.plugin import MultiTurnScheduler, multi_turns, orms from swift.trainers.rlhf_trainer.grpo_trainer import DataType from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, _create_parameter_buckets, - _process_bucket_with_flattened_tensor, - replace_assistant_response_with_ids) + _process_bucket_with_flattened_tensor, aggressive_empty_cache, + replace_assistant_response_with_ids, set_expandable_segments) from swift.utils import get_current_device, get_logger, is_master, is_vllm_available, remove_response from ..argument import MegatronArguments, MegatronRLHFArguments from ..utils import forward_step_helper @@ -160,10 +160,12 @@ def _prepare_rollout_engine(self): context = self.offload_context if self.enable_offload else nullcontext with context(): + set_expandable_segments(False) self.engine = self.prepare_vllm() if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) log_gpu_memory('after sleep vLLM engine') + set_expandable_segments(True) else: raise ValueError(f'Invalid vllm_mode: {self.vllm_mode}') @@ -205,11 +207,8 @@ def _move_model_to_vllm(self): if is_lora_training: self.merge_lora_adapters() - # Export weights from megatron models using bridge - per_tensor_params = dict(self.bridge.export_weights(self.unwrapped_models)) - - # Load weights to vLLM engine - self._load_weights_to_vllm(per_tensor_params) + # Export and load weights incrementally to avoid memory spikes + self._export_and_load_weights_incrementally() finally: # Unmerge adapters to restore training state @@ -222,32 +221,73 @@ def _move_model_to_vllm(self): elif self.vllm_mode == 'colocate': self.engine.engine.reset_prefix_cache() - def _load_weights_to_vllm(self, state_dict: Dict[str, torch.Tensor]): + def _export_and_load_weights_incrementally(self): """ - Load state_dict to vLLM engine using flattened tensor optimization. + Export weights from Megatron models and load to vLLM incrementally. - For server mode: Uses FlattenedTensorBucket to batch parameters and reduce communication overhead. - For colocate mode: Directly loads weights to the inner model. + For colocate mode: llm_model.load_weights accepts an iterator, so pass it directly. + For server mode: Process weights in buckets to avoid memory spikes. 
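+
+        Bucket size is controlled by the SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE
+        environment variable (in MB, default 512): exported (name, tensor) pairs
+        are accumulated until their combined byte size exceeds this threshold and
+        are then flushed to the server as a single FlattenedTensorBucket.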
""" - if self.vllm_mode == 'server' and self.is_main_process: - # Use flattened tensor optimization for efficient weight transfer - bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) - named_params = list(state_dict.items()) + # Export weights returns an iterator + weight_iterator = self.bridge.export_weights(self.unwrapped_models) - # Create parameter buckets for efficient processing - parameter_buckets = _create_parameter_buckets(named_params, bucket_size_mb=bucket_size_mb) + if self.vllm_mode == 'colocate': + # Colocate mode: load_weights supports iterator, pass directly + llm_model = self.engine.inner_model + llm_model.load_weights(weight_iterator) + elif self.vllm_mode == 'server' and self.is_main_process: + # Server mode: process in buckets and sync with flattened tensors + self._load_weights_to_server_in_buckets(weight_iterator) - # Process each bucket with flattened tensor - for bucket in parameter_buckets: - _process_bucket_with_flattened_tensor(self, bucket) + def _load_weights_to_server_in_buckets(self, weight_iterator): + """ + Load weights to vLLM server in buckets using FlattenedTensorBucket. - del named_params, parameter_buckets - elif self.vllm_mode == 'colocate': - # Colocate mode: direct weight loading - llm_model = self.engine.inner_model - llm_model.load_weights(state_dict.items()) + Args: + weight_iterator: Iterator of (name, tensor) tuples from export_weights + """ + # Get bucket size from environment or use default + bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + + current_bucket = [] + current_size = 0 + + for name, param in weight_iterator: + param_size = param.numel() * param.element_size() + current_bucket.append((name, param)) + current_size += param_size + + # If adding this param would exceed bucket size, process current bucket first + if current_size > bucket_size_bytes and current_bucket: + self._sync_bucket_to_server(current_bucket) + current_bucket = [] + current_size = 0 + + # Process remaining parameters in the last bucket + if current_bucket: + self._sync_bucket_to_server(current_bucket) + + def _sync_bucket_to_server(self, bucket_params: List[Tuple[str, torch.Tensor]]): + """ + Synchronize a bucket of parameters to vLLM server using flattened tensors. 
+ + Args: + bucket_params: List of (name, tensor) tuples to sync + """ + if not bucket_params: + return + + # Create FlattenedTensorBucket for efficient transfer + bucket = FlattenedTensorBucket(named_tensors=bucket_params) + metadatas = bucket.get_metadata() + flattened_tensor = bucket.get_flattened_tensor() + + # Directly call vllm_client to update weights + self.vllm_client.update_flattened_params(metadatas, flattened_tensor) - del state_dict + # Clean up to free memory immediately + del bucket, metadatas, flattened_tensor def _prepare_rewards(self): # TODO: reward model @@ -533,8 +573,9 @@ def _generate_completions(self, batch): wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters # Load weights only (faster and reduces memory peak) kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + log_gpu_memory(f'before wake up vLLM engine {kwargs}') self.engine.engine.wake_up(**kwargs) - log_gpu_memory(f'after wake up vLLM engine with {kwargs}') + log_gpu_memory(f'after wake up vLLM engine {kwargs}') # Step 2: Load model weights if self._step != self._last_loaded_step: @@ -545,8 +586,11 @@ def _generate_completions(self, batch): with context(): if (self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): + aggressive_empty_cache() + set_expandable_segments(False) + log_gpu_memory('before wake up vLLM engine kv_cache') self.engine.engine.wake_up(tags=['kv_cache']) - log_gpu_memory('after wake up vLLM engine with kv_cache') + log_gpu_memory('after wake up vLLM engine kv_cache') # Step3: Rollout batch = self.preprocess_rollout_data(batch) @@ -555,7 +599,10 @@ def _generate_completions(self, batch): # Step4: Sleep to release memory if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: self.engine.engine.reset_prefix_cache() + log_gpu_memory('before sleep vLLM engine') self.engine.engine.sleep(level=self.args.sleep_level) + aggressive_empty_cache() + set_expandable_segments(True) log_gpu_memory('after sleep vLLM engine') batch = self.postprocess_rollout_data(batch, outputs) @@ -699,24 +746,7 @@ def _server_rollout(self, return outputs def _colocate_rollout(self, batch, request_config: RequestConfig): - if self.vllm_tensor_parallel_size > 1: - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - local_input_length = len(batch) - all_input_lengths = [None] * self.vllm_tensor_parallel_size - torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.vllm_tp_group) - - start_idx = sum(all_input_lengths[:local_rank_in_group]) - end_idx = start_idx + all_input_lengths[local_rank_in_group] - - gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) - batch = [p for sublist in gathered_batch for p in sublist] - outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) - - if self.vllm_tensor_parallel_size > 1: - outputs = outputs[start_idx:end_idx] - return outputs def _score_completions(self, inputs: DataType) -> torch.Tensor: @@ -1037,16 +1067,38 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): if self.loss_type == 'grpo': loss_list = torch.split(per_token_loss.squeeze(0), lengths_with_padding.tolist()) mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) - sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) 
for loss, mask in zip(loss_list, mask_list)] - loss = torch.stack(sample_loss[:micro_batch_size]).mean() + + # In CP mode, aggregate numerator and denominator before division (Megatron standard) + if self.args.context_parallel_size > 1: + # Compute sum and count for each sample on this CP rank + sample_sum_and_count = torch.stack([ + torch.stack([(loss * mask).sum(), mask.sum().clamp(min=1.0)]) + for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) + ]) # Shape: [micro_batch_size, 2] + + # All-reduce to aggregate across CP ranks + all_reduce(sample_sum_and_count, group=mpu.get_context_parallel_group()) + + # Now compute per-sample loss and average + sample_loss = sample_sum_and_count[:, 0] / sample_sum_and_count[:, 1] + loss = sample_loss.mean() + else: + sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) + for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size])] + loss = torch.stack(sample_loss).mean() elif self.loss_type == 'bnpo': - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + if self.args.context_parallel_size > 1: + # Aggregate numerator and denominator across CP ranks + loss_and_count = torch.stack([(per_token_loss * completion_mask).sum(), + completion_mask.sum().clamp(min=1.0)]) + all_reduce(loss_and_count, group=mpu.get_context_parallel_group()) + loss = loss_and_count[0] / loss_and_count[1] + else: + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) else: raise ValueError(f'Unknown loss type: {self.loss_type}') - - loss = all_reduce(loss, group=mpu.get_context_parallel_group()) # loss = loss.mean() avg_metric = { 'loss': loss.clone().detach(), @@ -1079,9 +1131,12 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): } reporting_metric = {**avg_reporting_metric, **max_reporting_metric, **min_reporting_metric, **addition_metrics} - # fix megatron-lm bug - # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 - loss = loss / mpu.get_context_parallel_world_size() + # NOTE: For GRPO, CP loss aggregation is already handled in loss calculation (line 1046-1058) + # by aggregating numerator/denominator before division (Megatron standard pattern). + # DO NOT divide by CP size again here, or loss will be incorrect. + # For other loss types (bnpo), the aggregation is also handled inline (line 1064-1071). + # Only divide by CP size if there's a specific case that needs it (e.g., Megatron-LM bug fix + # for standard cross-entropy loss, but GRPO uses custom loss calculation). 
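+        # Illustration of the pattern above: with context_parallel_size=2, each CP
+        # rank only holds its shard of every sample's completion tokens, so summing
+        # (loss * mask) and the mask count across the CP group before dividing
+        # reproduces the full-sequence token mean exactly, and no further division
+        # by cp_size is needed here.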
return loss, reporting_metric def model_forward(self, model, data_iterator, no_grad=True, per_token=False): @@ -1207,7 +1262,7 @@ def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: def get_num_iters_per_step(self): if hasattr(self, '_num_iters_per_step'): return self._num_iters_per_step - # each rollout DP group will generate generation_batch_size / world_size completions + # each rollout DP group will generate generation_batch_size / dp_size completions dp_size = mpu.get_data_parallel_world_size() completions_to_rollout = self.generation_batch_size // dp_size # completions will be repeated num_generations times after diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index fe77b7f97b..27a247102d 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -347,11 +347,14 @@ def _iter_opts(opt): empty_cache() -def log_gpu_memory(prefix: str = ''): +def log_gpu_memory(prefix: str = '', info_once: bool = False): logger = get_logger() - - logger.info(f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' - f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') + log_msg = (f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' + f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') + if info_once: + logger.info_once(log_msg, hash_id=prefix) + else: + logger.info(log_msg) def should_filter_lora_parameter(name: str) -> bool: From bee6925d935dae6819ec030d28ffded8cc4e3cfd Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 7 Nov 2025 16:35:56 +0800 Subject: [PATCH 50/83] fix vllm tp --- swift/megatron/trainers/grpo_trainer.py | 40 +++++++++++-------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 27de3c333c..18226bab3a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -593,7 +593,6 @@ def _generate_completions(self, batch): log_gpu_memory('after wake up vLLM engine kv_cache') # Step3: Rollout - batch = self.preprocess_rollout_data(batch) outputs: List[RolloutOutput] = self._rollout(batch) # Step4: Sleep to release memory @@ -608,28 +607,6 @@ def _generate_completions(self, batch): return batch - def preprocess_rollout_data(self, batch): - """ - Gather rollout trajectories across the vLLM tensor-parallel (TP) group. - - This method collect the full batch on every rank, then flattens - the nested lists into a single list of samples. - - Args: - batch (list): List of rollout samples local to this TP rank. - - Returns: - list: Flattened list containing all rollout samples from every - rank in the TP group. 
- """ - if self.vllm_tensor_parallel_size == 1: - return batch - - gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) - flattened_batch = [p for sublist in gathered_batch for p in sublist] - return flattened_batch - def _rollout(self, batch) -> List[RolloutOutput]: request_config = self._get_request_config() # TODO: server mode @@ -746,7 +723,24 @@ def _server_rollout(self, return outputs def _colocate_rollout(self, batch, request_config: RequestConfig): + if self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + local_input_length = len(batch) + all_input_lengths = [None] * self.vllm_tensor_parallel_size + torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.vllm_tp_group) + + start_idx = sum(all_input_lengths[:local_rank_in_group]) + end_idx = start_idx + all_input_lengths[local_rank_in_group] + + gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) + batch = [p for sublist in gathered_batch for p in sublist] + outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) + + if self.vllm_tensor_parallel_size > 1: + outputs = outputs[start_idx:end_idx] + return outputs def _score_completions(self, inputs: DataType) -> torch.Tensor: From f92739ad2b528061ce30559bf5a5ac76c1ad1a7b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 7 Nov 2025 16:44:17 +0800 Subject: [PATCH 51/83] fix vllm tp distribute outputs twice --- swift/megatron/trainers/grpo_trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 18226bab3a..c8cf590544 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -632,12 +632,6 @@ def postprocess_rollout_data(self, batch, outputs): Updated samples with rollout results merged in. 
""" - if self.vllm_tensor_parallel_size > 1: - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - orig_size = len(outputs) // self.vllm_tensor_parallel_size - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - outputs = outputs[tp_slice] - def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput): response = output.response choice = response.choices[0] From bd3f9c94c3f3e01a1fe112dc9c44b36dffbd8f3d Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 7 Nov 2025 17:05:08 +0800 Subject: [PATCH 52/83] length context for template encode --- swift/megatron/trainers/grpo_trainer.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index c8cf590544..10dd17ecdc 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -494,8 +494,9 @@ def _generate_and_score_completions(self, batch): def _get_encoded_batch(rollout_batch, advantages): template = self.template - encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] - encoded_batch = to_device(template.data_collator(encoded_batch), self.device) + with self._template_context(template): + encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] + encoded_batch = to_device(template.data_collator(encoded_batch), self.device) labels = encoded_batch['labels'] assert self.template.padding_free position_ids = encoded_batch.get('text_position_ids') @@ -1283,3 +1284,13 @@ def get_local_rollout_batch(self, batch): data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) rollout_batch = global_rollout_batch[data_slice] return rollout_batch + + @contextmanager + def _template_context(self, template: Template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. 
+ max_length = template.max_length + template.max_length = None + try: + yield + finally: + template.max_length = max_length From 3f9b11d40b55476f60e5f500a36feaa62bc67538 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 9 Nov 2025 18:54:01 +0800 Subject: [PATCH 53/83] fix sequence is wip" --- swift/megatron/trainers/grpo_trainer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 10dd17ecdc..b1bc47557d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1043,6 +1043,13 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): coef_1 = torch.clamp(coef_1, max=self.args.delta) if self.template.padding_free: + # In padding_free + sequence mode, coef_1 is [num_samples, 1] + # We need to expand to [1, total_tokens] for token-level loss computation + if self.importance_sampling_level == 'sequence': + # Expand sequence-level weights to token-level without gradient + coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths, dim=0).unsqueeze(0) + coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths, dim=0).unsqueeze(0) + advantages = advantages[-coef_1.shape[1]:] per_token_loss1 = coef_1 * advantages.unsqueeze(0) per_token_loss2 = coef_2 * advantages.unsqueeze(0) From 1b4fd804bbebe3ac3f67c9b9fd8369355eadb875 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 10 Nov 2025 15:51:46 +0800 Subject: [PATCH 54/83] fix padding loss calculate --- swift/llm/dataset/dataset/llm.py | 7 +++++++ swift/megatron/trainers/grpo_trainer.py | 21 ++++++++++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 2cba486e3f..588a4ee621 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -925,3 +925,10 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: ], dataset_name='self-cognition', tags=['chat', 'self-cognition', '🔥'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='open-r1/DAPO-Math-17k-Processed', + hf_dataset_id='open-r1/DAPO-Math-17k-Processed', + subsets=['all'], + tags=['math', 'rlvr'])) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index b1bc47557d..518afb84db 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -16,6 +16,7 @@ from accelerate.utils import broadcast_object_list from dacite import from_dict from megatron.core import mpu +from megatron.core.rerun_state_machine import RerunDataIterator from megatron.training import get_args, training from torch.distributed.nn import all_reduce from trl.trainer.grpo_trainer import nanstd @@ -474,7 +475,7 @@ def _replace_data_iterator(self, data_iterator, model): self._buffered_inputs = mini_batch_data self._step += 1 inputs = self._buffered_inputs[self._step % self.steps_per_generation] - return iter(inputs) + return RerunDataIterator(iter(inputs)) def _generate_and_score_completions(self, batch): # Get or create the rollout group (TP×PP×CP) @@ -1024,13 +1025,15 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): elif self.importance_sampling_level in ['sequence', 'sequence_token']: log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist()) mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) - seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in 
zip(log_ratio_list, mask_list)] + seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) + for lr, m in zip(log_ratio_list[:micro_batch_size], mask_list[:micro_batch_size])] seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1) if self.importance_sampling_level == 'sequence': log_importance_weights = seq_level_log_weights else: seq_level_log_weight = seq_level_log_weights.detach() - seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0) + seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, + completion_mask.shape[1]).unsqueeze(0) log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( @@ -1047,15 +1050,16 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # We need to expand to [1, total_tokens] for token-level loss computation if self.importance_sampling_level == 'sequence': # Expand sequence-level weights to token-level without gradient - coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths, dim=0).unsqueeze(0) - coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths, dim=0).unsqueeze(0) + coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), completion_mask.shape[1], dim=0).unsqueeze(0) + coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), completion_mask.shape[1], dim=0).unsqueeze(0) advantages = advantages[-coef_1.shape[1]:] per_token_loss1 = coef_1 * advantages.unsqueeze(0) per_token_loss2 = coef_2 * advantages.unsqueeze(0) else: - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) + raise NotImplementedError + # per_token_loss1 = coef_1 * advantages.unsqueeze(1) + # per_token_loss2 = coef_2 * advantages.unsqueeze(1) per_token_loss = -torch.min(per_token_loss1, per_token_loss2) if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl @@ -1093,6 +1097,9 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + if self.args.context_parallel_size > 1: + all_reduce(loss, group=mpu.get_context_parallel_group()) + loss = loss / self.args.context_parallel_size else: raise ValueError(f'Unknown loss type: {self.loss_type}') # loss = loss.mean() From e9a307dec8c8a949be8900bbc11230724f9900c5 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 10 Nov 2025 19:58:18 +0800 Subject: [PATCH 55/83] log completions --- swift/megatron/argument/megatron_args.py | 37 +++--- swift/megatron/trainers/grpo_trainer.py | 137 +++++++++++++++++------ swift/megatron/trainers/utils.py | 24 +++- 3 files changed, 140 insertions(+), 58 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 3597b3ed9c..b6ec1be97a 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -35,7 +35,10 @@ class RLHFMegatronArgumentsMixin: undesirable_weight: float = 1. 
calculate_KL: Optional[bool] = None - # =========================== GRPO =========================== + # rm + center_rewards_coefficient: Optional[float] = None + + # grpo generation_batch_size: Optional[int] = None steps_per_generation: Optional[int] = None num_generations: int = 8 @@ -43,17 +46,15 @@ class RLHFMegatronArgumentsMixin: # GSPO https://www.arxiv.org/abs/2507.18071 importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' - # ─────────────────────────── Sampling ─────────────────────────── epsilon: float = 0.2 epsilon_high: Optional[float] = None delta: Optional[float] = None top_k: int = 50 top_p: float = 0.9 repetition_penalty: float = 1. - # ─────────────────────────── VLLM ─────────────────────────── use_vllm: bool = False vllm_mode: Literal['server', 'colocate'] = 'colocate' - # ────────────── Internal VLLM (colocate) ────────────── + vllm_enable_prefix_caching: bool = True vllm_gpu_memory_utilization: float = 0.9 vllm_tensor_parallel_size: int = 1 @@ -63,13 +64,11 @@ class RLHFMegatronArgumentsMixin: vllm_disable_cascade_attn: bool = False sleep_level: Literal[0, 1, 2] = 0 - # ────────────── External VLLM (server, not supported yet) ────────────── vllm_server_base_url: Optional[List[str]] = None vllm_server_host: Optional[List[str]] = None vllm_server_port: List[int] = field(default_factory=lambda: [8000]) vllm_server_timeout: float = 240.0 - # ─────────────────────────── Reward ─────────────────────────── reward_funcs: List[str] = field(default_factory=list) reward_weights: List[float] = None # see details in swift/plugin/orm.py @@ -85,6 +84,16 @@ class RLHFMegatronArgumentsMixin: # soft_overlong, https://arxiv.org/abs/2503.14476 soft_max_length: Optional[int] = None soft_cache_length: Optional[int] = None + # DAPO, https://arxiv.org/abs/2503.14476 + dynamic_sample: bool = False + max_resample_times: int = 3 + overlong_filter: bool = False + + # Dr. GRPO, https://arxiv.org/abs/2503.20783 + scale_rewards: bool = True + + wandb_log_unique_prompts: Optional[bool] = None + log_completions: bool = False # ─────────────────────────── Not Supported Yet ─────────────────────────── # reward model @@ -100,37 +109,23 @@ class RLHFMegatronArgumentsMixin: move_model_batches: Optional[int] = None offload_optimizer: bool = False offload_model: bool = False - gc_collect_after_offload: bool = False # deprecated # multi turn - multi_turn_func: Optional[str] = None # deprecated multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round' vllm_server_pass_dataset: bool = False - # DAPO, https://arxiv.org/abs/2503.14476 - dynamic_sample: bool = False - max_resample_times: int = 3 - overlong_filter: bool = False - - # Dr. 
GRPO, https://arxiv.org/abs/2503.20783 - scale_rewards: bool = True - # entropy log_entropy: bool = False # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 top_entropy_quantile: float = 1.0 - wandb_log_unique_prompts: Optional[bool] = None num_iterations: int = 1 # dataset dataset_shuffle: Optional[bool] = True - # rm - center_rewards_coefficient: Optional[float] = None - def _init_kto(self): if self.calculate_KL is None: # Not all losses require a KL calculation @@ -395,7 +390,7 @@ class MegatronArguments(ExtraMegatronArguments): no_load_rng: bool = False finetune: bool = False ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' - no_initialization: bool = False + no_initialization: bool = True auto_detect_ckpt_format: bool = True exit_on_missing_checkpoint: bool = True async_save: bool = False diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 518afb84db..7b7de1f8a0 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -11,31 +11,36 @@ from typing import Any, Dict, List, Tuple, Union import json +import pandas as pd import torch import torch.nn as nn from accelerate.utils import broadcast_object_list from dacite import from_dict from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator -from megatron.training import get_args, training +from megatron.training import get_args, get_wandb_writer, training from torch.distributed.nn import all_reduce from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps from swift.llm import RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device from swift.llm.infer.protocol import RolloutOutput +from swift.llm.template.template_inputs import TemplateInputs from swift.plugin import MultiTurnScheduler, multi_turns, orms from swift.trainers.rlhf_trainer.grpo_trainer import DataType -from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, _create_parameter_buckets, - _process_bucket_with_flattened_tensor, aggressive_empty_cache, +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, aggressive_empty_cache, replace_assistant_response_with_ids, set_expandable_segments) -from swift.utils import get_current_device, get_logger, is_master, is_vllm_available, remove_response +from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available, + remove_response) from ..argument import MegatronArguments, MegatronRLHFArguments from ..utils import forward_step_helper from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, offload_megatron_model_to_cpu, offload_megatron_optimizer, - profiling_context) + patch_profiling_context, patch_profiling_decorator, profiling_context) + +if is_wandb_available(): + import wandb try: from trl.trainer.utils import entropy_from_logits @@ -53,6 +58,7 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): self.args = args self.hf_model_dir = args.model_info.model_dir self.processing_class = self.template.processor + self._prepare_metrics() self._prepare_template_data_collator() self._init_grpo_params() self._prepare_rewards() @@ -88,7 +94,7 @@ def _init_grpo_params(self): # distributed params self.world_size = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() - 
self.is_main_process = is_master() + self.is_main_process = is_last_rank() self.device = get_current_device() # algorithm params self.num_generations = args.num_generations # G in the GRPO paper @@ -140,7 +146,7 @@ def _prepare_rollout_engine(self): self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.use_vllm = args.use_vllm self.async_generate = args.async_generate - self.use_fast_infer = self.use_vllm # whether to use the PT backend + self.use_fast_infer = self.use_vllm self.vllm_use_async_engine = False self.enable_offload = False self.use_gym_env = False @@ -200,6 +206,7 @@ def prepare_vllm(self): self._buffered_inputs = None return engine + @patch_profiling_decorator def _move_model_to_vllm(self): # Handle LoRA: merge adapters before exporting weights is_lora_training = self.args.train_type == 'lora' @@ -209,7 +216,7 @@ def _move_model_to_vllm(self): self.merge_lora_adapters() # Export and load weights incrementally to avoid memory spikes - self._export_and_load_weights_incrementally() + self._export_and_load_weights() finally: # Unmerge adapters to restore training state @@ -222,7 +229,13 @@ def _move_model_to_vllm(self): elif self.vllm_mode == 'colocate': self.engine.engine.reset_prefix_cache() - def _export_and_load_weights_incrementally(self): + @property + def bridge(self): + if self._bridge is None: + self._bridge = self.args.megatron_model_meta.bridge_cls(disable_tqmd=True) + return self._bridge + + def _export_and_load_weights(self): """ Export weights from Megatron models and load to vLLM incrementally. @@ -230,7 +243,8 @@ def _export_and_load_weights_incrementally(self): For server mode: Process weights in buckets to avoid memory spikes. """ # Export weights returns an iterator - weight_iterator = self.bridge.export_weights(self.unwrapped_models) + with patch_profiling_context(self, 'export_weights'): + weight_iterator = self.bridge.export_weights(self.unwrapped_models) if self.vllm_mode == 'colocate': # Colocate mode: load_weights supports iterator, pass directly @@ -549,11 +563,13 @@ def _get_encoded_batch(rollout_batch, advantages): micro_batch_data = self._maybe_replace_response_token(micro_batch_data) micro_batch_advantages = total_advantages[idx:idx + self.micro_batch_size] micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages) - micro_batch_data = self._maybe_compute_logps(micro_batch_data) + with patch_profiling_context(self, 'compute_ref_old_logps'): + micro_batch_data = self._maybe_compute_logps(micro_batch_data) mini_batch_data.append(micro_batch_data) return mini_batch_data + @patch_profiling_decorator def _generate_completions(self, batch): """ Generate completions for a batch of rollout data using vLLM engine. 
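The weight-sync path above exports Megatron weights as an iterator of (name, tensor) pairs and streams them into the inference engine instead of materializing a full state dict, which keeps peak memory at roughly one tensor (or one bucket) at a time. A minimal, self-contained sketch of that pattern; export_weights and DummyInferenceEngine are illustrative stand-ins, not ms-swift or vLLM APIs:

import torch
import torch.nn as nn

def export_weights(models):
    # Yield one (name, tensor) pair at a time instead of building a state_dict.
    for model in models:
        for name, param in model.named_parameters():
            yield name, param.detach()

class DummyInferenceEngine:
    def __init__(self):
        self.weights = {}

    def load_weights(self, weight_iterator):
        # Consume the iterator lazily; each tensor can be freed after copying.
        for name, tensor in weight_iterator:
            self.weights[name] = tensor.clone()

trainer_model = nn.Linear(4, 4)
engine = DummyInferenceEngine()
engine.load_weights(export_weights([trainer_model]))
assert set(engine.weights) == {'weight', 'bias'}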
@@ -616,6 +632,12 @@ def _rollout(self, batch) -> List[RolloutOutput]: rollout_outputs = self._server_rollout(batch, request_config) elif self.vllm_mode == 'colocate': rollout_outputs = self._colocate_rollout(batch, request_config) + # log prompt and completions + messages = gather_object([data['messages'] for data in batch]) + completions = gather_object([data.response.choices[0].message.content for data in rollout_outputs]) + self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(messages)) + self._logs['completion'].extend(completions) + return rollout_outputs def postprocess_rollout_data(self, batch, outputs): @@ -739,6 +761,7 @@ def _colocate_rollout(self, batch, request_config: RequestConfig): return outputs + @patch_profiling_decorator def _score_completions(self, inputs: DataType) -> torch.Tensor: """Score completions using all reward functions. @@ -827,7 +850,10 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor self._metrics[mode][f'rewards/{name}/mean'].append(torch.nanmean(col).item()) self._metrics[mode][f'rewards/{name}/std'].append(nanstd(col).item()) - log_rewards_metrics(rewards=grouped_rewards, rewards_per_func_for_metrics=rewards_per_func) + log_rewards_metrics(rewards=grouped_rewards, rewards_per_func_for_metrics=total_rewards_per_func) + self._logs['advantages'].extend(advantages.tolist()) + for i, name in enumerate(self.reward_func_names): + self._logs['rewards'][name].extend(total_rewards_per_func[:, i].tolist()) slice_start = self.process_index * len(batch) slice_end = slice_start + len(batch) @@ -972,6 +998,7 @@ def build_pretraining_data_loader(*_args, **kwargs): finally: training.build_pretraining_data_loader = origin_build_pretraining_data_loader + @patch_profiling_decorator def forward_step(self, data_iterator, model): # train_batch_size # return: output_tensor, loss_func @@ -987,6 +1014,7 @@ def forward_step(self, data_iterator, model): output_tensor = model(**inputs) return output_tensor, partial(self.loss_func, data=data) + @patch_profiling_decorator def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): advantages = data['advantages'] labels = data['labels'] @@ -1010,6 +1038,9 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): if self.args.overlong_filter and truncated_mask.any(): completion_mask = completion_mask & (~truncated_mask) + if not completion_mask.any(): + logger.warning('All completions are truncated in this batch. Loss and grad_norm will be 0. 
' + 'Consider increasing max_completion_length') if self.beta != 0.0: ref_per_token_logps = data.get('ref_per_token_logps') @@ -1053,7 +1084,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), completion_mask.shape[1], dim=0).unsqueeze(0) coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), completion_mask.shape[1], dim=0).unsqueeze(0) - advantages = advantages[-coef_1.shape[1]:] + advantages = advantages[:coef_1.shape[1]] per_token_loss1 = coef_1 * advantages.unsqueeze(0) per_token_loss2 = coef_2 * advantages.unsqueeze(0) else: @@ -1105,41 +1136,51 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # loss = loss.mean() avg_metric = { 'loss': loss.clone().detach(), - 'completions/mean_length': lengths.float().mean(), - } - max_metric = { - 'completions/max_length': lengths.float().max(), - } - min_metric = { - 'completions/min_length': lengths.float().min(), } + custom_metrics = {} + if self.args.context_parallel_size == 1: + total_lengths = gather(lengths) + custom_metrics = { + 'completions/mean_length': total_lengths.float().mean(), + 'completions/max_length': total_lengths.float().max(), + 'completions/min_length': total_lengths.float().min(), + } + if self.beta != 0.0: avg_metric['kl'] = per_token_kl.mean().item() avg_reporting_metric = loss.new_tensor(list(avg_metric.values())) - max_reporting_metric = loss.new_tensor(list(max_metric.values())) - min_reporting_metric = loss.new_tensor(list(min_metric.values())) torch.distributed.all_reduce( avg_reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce( - max_reporting_metric, torch.distributed.ReduceOp.MAX, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce( - min_reporting_metric, torch.distributed.ReduceOp.MIN, group=mpu.get_data_parallel_group()) avg_reporting_metric = {k: avg_reporting_metric[i] for i, k in enumerate(avg_metric.keys())} - max_reporting_metric = {k: max_reporting_metric[i] for i, k in enumerate(max_metric.keys())} - min_reporting_metric = {k: min_reporting_metric[i] for i, k in enumerate(min_metric.keys())} + mode = 'train' if self.unwrapped_models[0].training else 'eval' addition_metrics = { key: torch.tensor(sum(val) / len(val), device=loss.device) - for key, val in self._metrics['train'].items() + for key, val in self._metrics[mode].items() } - reporting_metric = {**avg_reporting_metric, **max_reporting_metric, **min_reporting_metric, **addition_metrics} - # NOTE: For GRPO, CP loss aggregation is already handled in loss calculation (line 1046-1058) - # by aggregating numerator/denominator before division (Megatron standard pattern). - # DO NOT divide by CP size again here, or loss will be incorrect. - # For other loss types (bnpo), the aggregation is also handled inline (line 1064-1071). - # Only divide by CP size if there's a specific case that needs it (e.g., Megatron-LM bug fix - # for standard cross-entropy loss, but GRPO uses custom loss calculation). 
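# Standalone sketch (not part of the diff above): in a packed, padding-free
# batch, sequence-level importance weights have to be expanded to one value
# per token of each sequence, i.e. each weight is repeated by its *own*
# sequence length. Repeating by the total packed length
# (completion_mask.shape[1]) instead produces a tensor of the wrong size,
# which is what the follow-up patches in this series correct.
import torch

seq_weights = torch.tensor([0.9, 1.1, 1.0])   # one importance weight per sequence
lengths = torch.tensor([2, 3, 1])             # packed per-sequence lengths

token_weights = torch.repeat_interleave(seq_weights, lengths).unsqueeze(0)
assert token_weights.shape == (1, int(lengths.sum()))
# token_weights -> [[0.9, 0.9, 1.1, 1.1, 1.1, 1.0]]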
+ reporting_metric = {**avg_reporting_metric, **addition_metrics, **custom_metrics} + # log_completions + if self.log_completions and self.is_main_process and self._step % self.steps_per_generation == 0: + table = { + 'gen_step': [self._step] * len(self._logs['prompt']), + 'prompt': list(self._logs['prompt']), + 'completion': list(self._logs['completion']), + **{k: list(v) + for k, v in self._logs['rewards'].items()}, + 'advantages': list(self._logs['advantages']), + } + self.jsonl_writer.append(table) + wandb_writer = get_wandb_writer() + if wandb_writer: + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=['prompt']) + if not self.init_custom_metric: + wandb_writer.define_metric('completions', step_metric='gen_step') + self.init_custom_metric = True + wandb_writer.log({'completions': wandb.Table(dataframe=df)}, self._step) + return loss, reporting_metric def model_forward(self, model, data_iterator, no_grad=True, per_token=False): @@ -1308,3 +1349,27 @@ def _template_context(self, template: Template): yield finally: template.max_length = max_length + + def _prepare_metrics(self): + args = self.args + from swift.utils import JsonlWriter + from collections import deque + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl')) + self.init_custom_metric = False + self._logs = { + 'prompt': deque(maxlen=args.generation_batch_size), + 'completion': deque(maxlen=args.generation_batch_size), + 'rewards': defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + 'advantages': deque(maxlen=args.generation_batch_size), + } + + def _apply_chat_template_to_messages_list(self, messages_list: DataType): + prompts_text = [] + for messages in messages_list: + remove_response(messages) + template_inputs = TemplateInputs.from_dict({'messages': messages}) + res = self.template.encode(template_inputs) + prompts_text.append(self.template.safe_decode(res['input_ids'])) + return prompts_text diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 114665f089..a19fe5caf0 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import functools import gc import time from contextlib import contextmanager @@ -135,7 +136,28 @@ def profiling_context(trainer, name: str): if wandb_writer and trainer.is_main_process: wandb_writer.log(profiling_metrics) - # TODO: add swanlab support + +@contextmanager +def patch_profiling_context(trainer, name: str): + start_time = time.perf_counter() + yield + end_time = time.perf_counter() + duration = end_time - start_time + + profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration} + wandb_writer = get_wandb_writer() + if wandb_writer and trainer.is_main_process: + wandb_writer.log(profiling_metrics) + + +def patch_profiling_decorator(func): + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with patch_profiling_context(self, func.__name__): + return func(self, *args, **kwargs) + + return wrapper def gather(tensor, group: Optional[torch.distributed.ProcessGroup] = None): From 8b1e40767a41bc9c288fded2bc4af3bc563c1b4a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 10 Nov 2025 21:06:18 +0800 Subject: [PATCH 56/83] bug todo --- swift/megatron/trainers/grpo_trainer.py | 27 ++++++++++++++++--------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 7b7de1f8a0..22b0ae1658 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -19,6 +19,7 @@ from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator from megatron.training import get_args, get_wandb_writer, training +from torch._tensor import Tensor from torch.distributed.nn import all_reduce from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps @@ -1063,8 +1064,10 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): log_importance_weights = seq_level_log_weights else: seq_level_log_weight = seq_level_log_weights.detach() - seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, - completion_mask.shape[1]).unsqueeze(0) + seq_level_log_weight = torch.cat([ + torch.repeat_interleave(log_weight, length) + for log_weight, length in zip(seq_level_log_weight, lengths.tolist()) + ]) log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( @@ -1081,10 +1084,14 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # We need to expand to [1, total_tokens] for token-level loss computation if self.importance_sampling_level == 'sequence': # Expand sequence-level weights to token-level without gradient - coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), completion_mask.shape[1], dim=0).unsqueeze(0) - coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), completion_mask.shape[1], dim=0).unsqueeze(0) - - advantages = advantages[:coef_1.shape[1]] + coef_1 = torch.cat([ + torch.repeat_interleave(log_weight, length) for log_weight, length in zip(coef_1, lengths.tolist()) + ]) + coef_2 = torch.cat([ + torch.repeat_interleave(log_weight, length) for log_weight, length in zip(coef_2, lengths.tolist()) + ]) + + # advantages = advantages[:coef_1.shape[1]] per_token_loss1 = coef_1 * advantages.unsqueeze(0) per_token_loss2 = coef_2 * advantages.unsqueeze(0) else: @@ -1176,10 +1183,10 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): df = pd.DataFrame(table) if self.wandb_log_unique_prompts: df = df.drop_duplicates(subset=['prompt']) - if not 
self.init_custom_metric: - wandb_writer.define_metric('completions', step_metric='gen_step') - self.init_custom_metric = True - wandb_writer.log({'completions': wandb.Table(dataframe=df)}, self._step) + # if not self.init_custom_metric: + # wandb_writer.define_metric('completions', step_metric='gen_step') + # self.init_custom_metric = True + wandb_writer.log({'completions': wandb.Table(dataframe=df)}) return loss, reporting_metric From b6c10ebdeeedc1ffc8d10d2160fb37d62a91211b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 10 Nov 2025 23:33:03 +0800 Subject: [PATCH 57/83] fix wip --- swift/megatron/trainers/grpo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 22b0ae1658..15998a83b0 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1067,7 +1067,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): seq_level_log_weight = torch.cat([ torch.repeat_interleave(log_weight, length) for log_weight, length in zip(seq_level_log_weight, lengths.tolist()) - ]) + ]).unsqueeze(0) log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( @@ -1086,12 +1086,12 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # Expand sequence-level weights to token-level without gradient coef_1 = torch.cat([ torch.repeat_interleave(log_weight, length) for log_weight, length in zip(coef_1, lengths.tolist()) - ]) + ]).unsqueeze(0) coef_2 = torch.cat([ torch.repeat_interleave(log_weight, length) for log_weight, length in zip(coef_2, lengths.tolist()) - ]) + ]).unsqueeze(0) - # advantages = advantages[:coef_1.shape[1]] + advantages = advantages[-coef_1.shape[1]:] per_token_loss1 = coef_1 * advantages.unsqueeze(0) per_token_loss2 = coef_2 * advantages.unsqueeze(0) else: From 46b8e2a2a69b435e105cd030677cb26afbfc750f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 12 Nov 2025 09:30:26 +0800 Subject: [PATCH 58/83] fix sp padding --- swift/megatron/trainers/grpo_trainer.py | 51 +++++++++++++++---------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 15998a83b0..56ce829141 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1057,17 +1057,18 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): elif self.importance_sampling_level in ['sequence', 'sequence_token']: log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist()) mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) - seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) - for lr, m in zip(log_ratio_list[:micro_batch_size], mask_list[:micro_batch_size])] - seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1) + # Optimized: compute weighted sum for each sequence (avoid list comprehension overhead) + # Use torch.stack on results instead of intermediate lists + seq_weights = torch.stack([(lr * m).sum() / m.sum().clamp(min=1.0) + for lr, m in zip(log_ratio_list, mask_list)]) + seq_level_log_weights = seq_weights.to(log_ratio.dtype).unsqueeze(-1) if self.importance_sampling_level == 'sequence': log_importance_weights = seq_level_log_weights else: seq_level_log_weight = seq_level_log_weights.detach() - seq_level_log_weight = torch.cat([ - 
torch.repeat_interleave(log_weight, length) - for log_weight, length in zip(seq_level_log_weight, lengths.tolist()) - ]).unsqueeze(0) + # Vectorized: use repeat_interleave with tensor directly + seq_level_log_weight = torch.repeat_interleave( + seq_level_log_weight.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0) log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( @@ -1083,13 +1084,9 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # In padding_free + sequence mode, coef_1 is [num_samples, 1] # We need to expand to [1, total_tokens] for token-level loss computation if self.importance_sampling_level == 'sequence': - # Expand sequence-level weights to token-level without gradient - coef_1 = torch.cat([ - torch.repeat_interleave(log_weight, length) for log_weight, length in zip(coef_1, lengths.tolist()) - ]).unsqueeze(0) - coef_2 = torch.cat([ - torch.repeat_interleave(log_weight, length) for log_weight, length in zip(coef_2, lengths.tolist()) - ]).unsqueeze(0) + # Vectorized: expand sequence-level weights to token-level without gradient + coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0) + coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0) advantages = advantages[-coef_1.shape[1]:] per_token_loss1 = coef_1 * advantages.unsqueeze(0) @@ -1108,7 +1105,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # In CP mode, aggregate numerator and denominator before division (Megatron standard) if self.args.context_parallel_size > 1: - # Compute sum and count for each sample on this CP rank + # Optimized: compute sum and count for each sample, then stack results sample_sum_and_count = torch.stack([ torch.stack([(loss * mask).sum(), mask.sum().clamp(min=1.0)]) for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) @@ -1121,9 +1118,12 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): sample_loss = sample_sum_and_count[:, 0] / sample_sum_and_count[:, 1] loss = sample_loss.mean() else: - sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) - for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size])] - loss = torch.stack(sample_loss).mean() + # Optimized: compute sample loss, then stack results + sample_loss = torch.stack([ + (loss * mask).sum() / mask.sum().clamp(min=1.0) + for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) + ]) + loss = sample_loss.mean() elif self.loss_type == 'bnpo': if self.args.context_parallel_size > 1: # Aggregate numerator and denominator across CP ranks @@ -1153,8 +1153,8 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): 'completions/min_length': total_lengths.float().min(), } - if self.beta != 0.0: - avg_metric['kl'] = per_token_kl.mean().item() + # if self.beta != 0.0: + # avg_metric['kl'] = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0).item() avg_reporting_metric = loss.new_tensor(list(avg_metric.values())) torch.distributed.all_reduce( avg_reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) @@ -1371,6 +1371,17 @@ def _prepare_metrics(self): 'rewards': defaultdict(lambda: deque(maxlen=args.generation_batch_size)), 'advantages': deque(maxlen=args.generation_batch_size), } + if is_wandb_available(): + from wandb.sdk.wandb_run import Run + origin_log = Run.log + from 
functools import wraps + + @wraps(origin_log) + def log(self, data: dict[str, Any], step: int | None = None, commit: bool | None = None): + return origin_log(self, data, None, commit) + + # Directly replace the class method, no need for MethodType + Run.log = log def _apply_chat_template_to_messages_list(self, messages_list: DataType): prompts_text = [] From 87bbad7451ea28183bff6bc0887388469ab91c2b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 12 Nov 2025 11:11:06 +0800 Subject: [PATCH 59/83] fix pp --- swift/megatron/trainers/grpo_trainer.py | 37 +++++++++++++------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 56ce829141..50423bc4f7 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1145,28 +1145,29 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): 'loss': loss.clone().detach(), } custom_metrics = {} - if self.args.context_parallel_size == 1: - total_lengths = gather(lengths) - custom_metrics = { - 'completions/mean_length': total_lengths.float().mean(), - 'completions/max_length': total_lengths.float().max(), - 'completions/min_length': total_lengths.float().min(), - } + total_lengths = gather(lengths, group=mpu.get_data_parallel_group(with_context_parallel=True)) + total_lengths *= self.args.context_parallel_size + custom_metrics = { + 'completions/mean_length': total_lengths.float().mean(), + 'completions/max_length': total_lengths.float().max(), + 'completions/min_length': total_lengths.float().min(), + } - # if self.beta != 0.0: - # avg_metric['kl'] = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0).item() - avg_reporting_metric = loss.new_tensor(list(avg_metric.values())) - torch.distributed.all_reduce( - avg_reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + if self.beta != 0.0: + avg_metric['kl'] = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0).item() - avg_reporting_metric = {k: avg_reporting_metric[i] for i, k in enumerate(avg_metric.keys())} mode = 'train' if self.unwrapped_models[0].training else 'eval' - addition_metrics = { - key: torch.tensor(sum(val) / len(val), device=loss.device) - for key, val in self._metrics[mode].items() - } + if self._metrics[mode]: + addition_metrics = { + key: torch.tensor(sum(val) / len(val), device=loss.device) + for key, val in self._metrics[mode].items() + } + avg_metric.update(addition_metrics) + + avg_metric = self._all_reduce_metric(avg_metric) + + reporting_metric = {**avg_metric, **custom_metrics} - reporting_metric = {**avg_reporting_metric, **addition_metrics, **custom_metrics} # log_completions if self.log_completions and self.is_main_process and self._step % self.steps_per_generation == 0: table = { From 19277121c9d250a8f88f539d1bd519ef7af4298e Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 12 Nov 2025 16:15:40 +0800 Subject: [PATCH 60/83] revert to full seq loss for cp --- swift/megatron/trainers/grpo_trainer.py | 53 +++++----------------- swift/megatron/trainers/rlhf_mixin.py | 60 +++++++++++++++++++++++++ swift/megatron/trainers/utils.py | 40 ----------------- 3 files changed, 71 insertions(+), 82 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 50423bc4f7..8f6c3d67dd 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1023,17 
+1023,12 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): packed_seq_params = data['packed_seq_params'] truncated_mask = data['truncated_mask'] micro_batch_size = self.micro_batch_size + # Use full sequence lengths directly (get_logps returns full sequences in CP mode) lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size] lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] - if mpu.get_context_parallel_world_size() > 1: - # When using Context Parallel, each rank only processes a portion of the sequence - # So we need to divide the lengths by CP size - cp_size = mpu.get_context_parallel_world_size() - cu_seqlens_cp = packed_seq_params.cu_seqlens_q // cp_size - lengths_with_padding = cu_seqlens_cp[1:] - cu_seqlens_cp[:-1] - lengths = cu_seqlens_cp[1:micro_batch_size + 1] - cu_seqlens_cp[:micro_batch_size] + # get_logps with per_token=True now returns full sequences (all_gather in CP mode) per_token_logps = self.get_logps( output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True) @@ -1103,50 +1098,22 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): loss_list = torch.split(per_token_loss.squeeze(0), lengths_with_padding.tolist()) mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) - # In CP mode, aggregate numerator and denominator before division (Megatron standard) - if self.args.context_parallel_size > 1: - # Optimized: compute sum and count for each sample, then stack results - sample_sum_and_count = torch.stack([ - torch.stack([(loss * mask).sum(), mask.sum().clamp(min=1.0)]) - for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) - ]) # Shape: [micro_batch_size, 2] - - # All-reduce to aggregate across CP ranks - all_reduce(sample_sum_and_count, group=mpu.get_context_parallel_group()) - - # Now compute per-sample loss and average - sample_loss = sample_sum_and_count[:, 0] / sample_sum_and_count[:, 1] - loss = sample_loss.mean() - else: - # Optimized: compute sample loss, then stack results - sample_loss = torch.stack([ - (loss * mask).sum() / mask.sum().clamp(min=1.0) - for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) - ]) - loss = sample_loss.mean() + sample_loss = torch.stack([(loss * mask).sum() / mask.sum().clamp(min=1.0) + for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) + ]) + loss = sample_loss.mean() elif self.loss_type == 'bnpo': - if self.args.context_parallel_size > 1: - # Aggregate numerator and denominator across CP ranks - loss_and_count = torch.stack([(per_token_loss * completion_mask).sum(), - completion_mask.sum().clamp(min=1.0)]) - all_reduce(loss_and_count, group=mpu.get_context_parallel_group()) - loss = loss_and_count[0] / loss_and_count[1] - else: - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - if self.args.context_parallel_size > 1: - all_reduce(loss, group=mpu.get_context_parallel_group()) - loss = loss / self.args.context_parallel_size else: raise ValueError(f'Unknown loss type: {self.loss_type}') - # loss = loss.mean() + avg_metric = { 'loss': loss.clone().detach(), } custom_metrics = {} 
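# Standalone sketch (not part of the diff above): how the three loss_type
# branches normalize the same masked per-token loss over a packed batch of
# two sequences. All numbers are made up; max_completion_length=4 and
# micro_batch_size=2 are assumptions for the example only.
import torch

per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 0.0]])
lengths = [3, 3]                      # two packed sequences of 3 tokens each
micro_batch_size, max_completion_length = 2, 4

loss_list = torch.split(per_token_loss.squeeze(0), lengths)
mask_list = torch.split(completion_mask.squeeze(0), lengths)

# 'grpo': masked mean per sample, then mean over samples
grpo = torch.stack([(l * m).sum() / m.sum().clamp(min=1.0)
                    for l, m in zip(loss_list, mask_list)]).mean()          # 3.25
# 'bnpo': one masked mean over every completion token in the batch
bnpo = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)  # 3.0
# 'dr_grpo': fixed denominator, independent of the realized completion lengths
dr_grpo = (per_token_loss * completion_mask).sum() / (micro_batch_size * max_completion_length)  # 1.875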
total_lengths = gather(lengths, group=mpu.get_data_parallel_group(with_context_parallel=True)) - total_lengths *= self.args.context_parallel_size custom_metrics = { 'completions/mean_length': total_lengths.float().mean(), 'completions/max_length': total_lengths.float().max(), @@ -1154,7 +1121,9 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): } if self.beta != 0.0: - avg_metric['kl'] = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0).item() + # Unified processing (no CP-specific logic needed) + kl_value = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + avg_metric['kl'] = kl_value.item() mode = 'train' if self.unwrapped_models[0].training else 'eval' if self._metrics[mode]: diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 47df78fcdc..1c4efce1c9 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -62,6 +62,10 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask if per_token: + # In CP mode, all_gather and reconstruct full sequence + if args.context_parallel_size > 1: + per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples + or packed_seq_params.num_samples) return per_token_logps if num_samples is None: @@ -74,3 +78,59 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, if args.context_parallel_size > 1: all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) return all_logps + + def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples): + """ + Generic method: In CP mode, all_gather and reconstruct full tensor sequences. + Works for both logps (float) and masks (bool/int). 
+ + Args: + tensor: [1, packed_len/cp_size] - CP-split tensor (any dtype) + packed_seq_params: PackedSeqParams object + num_samples: Number of samples in the batch + + Returns: + output_full: [1, packed_len] - Full sequence tensor + """ + args = get_args() + cp_size = args.context_parallel_size + cp_rank = mpu.get_context_parallel_rank() + + # All-gather across CP ranks + output_list = [torch.empty_like(tensor) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, tensor.contiguous(), group=mpu.get_context_parallel_group()) + output_list[cp_rank] = tensor + + # Reconstruct full sequence + # Shape: [1, packed_len/cp_size] -> [1, packed_len] + cu_seqlens_full = packed_seq_params.cu_seqlens_q + cu_seqlens_cp = cu_seqlens_full // cp_size + + # Calculate total packed length + total_packed_len = cu_seqlens_full[num_samples].item() + output_full = tensor.new_zeros(1, total_packed_len) + + # Reconstruct each sequence + for i in range(num_samples): + start_full = cu_seqlens_full[i].item() + end_full = cu_seqlens_full[i + 1].item() + seq_len = end_full - start_full + + # Length of each chunk after CP split + chunk_len = seq_len // cp_size + half_chunk = chunk_len // 2 + + # Concatenate from each CP rank's output (load-balanced split) + for j in range(cp_size): + o = output_list[j][0] + start_cp = cu_seqlens_cp[i].item() + + # Get two half chunks (CP's load-balanced split) + o0 = o[start_cp:start_cp + half_chunk] + o1 = o[start_cp + half_chunk:start_cp + chunk_len] + + # Place back to full sequence + output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0 + output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1 + + return output_full diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index a19fe5caf0..c0e19ccb77 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -105,10 +105,6 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): keys.append('decoder_input') else: keys.append('input_ids') - if hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo': - keys.append('truncated_mask') - keys.append('advantages') - keys.append('completion_mask') packed_seq_params = batch.get('packed_seq_params') if packed_seq_params is None: @@ -377,39 +373,3 @@ def log_gpu_memory(prefix: str = '', info_once: bool = False): logger.info_once(log_msg, hash_id=prefix) else: logger.info(log_msg) - - -def should_filter_lora_parameter(name: str) -> bool: - if 'lora_' in name: - return True - - if 'original_module' in name: - return True - return False - - -def patch_model_for_lora_export(model): - original_named_parameters = model.named_parameters - original_state_dict = model.state_dict - - def filtered_named_parameters(*args, **kwargs): - for name, param in original_named_parameters(*args, **kwargs): - if not should_filter_lora_parameter(name): - yield name, param - - def filtered_state_dict(*args, **kwargs): - state_dict = original_state_dict(*args, **kwargs) - filtered = {} - for name, param in state_dict.items(): - if not should_filter_lora_parameter(name): - filtered[name] = param - return filtered - - model.named_parameters = filtered_named_parameters - model.state_dict = filtered_state_dict - - def restore(): - model.named_parameters = original_named_parameters - model.state_dict = original_state_dict - - return restore From 44920c06a35989e027cbaa238b8df18ba3fa5bee Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 13 Nov 2025 16:25:03 +0800 Subject: [PATCH 61/83] fix server client init 
in first rank instead of last rank --- swift/megatron/train/rlhf.py | 4 ++-- swift/megatron/trainers/grpo_trainer.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 21cd9ceb92..d27f5aabb2 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -3,7 +3,7 @@ from swift.llm.train.kto import prepare_kto_dataset from swift.trainers.rlhf_trainer.utils import identity_data_collator -from swift.utils import get_current_device, get_logger, is_master +from swift.utils import get_current_device, get_logger, is_last_rank from ..argument import MegatronRLHFArguments from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer, MegatronKTOTrainer, MegatronRewardTrainer from .sft import MegatronSft @@ -53,7 +53,7 @@ def _prepare_vllm_client(self): return from swift.trainers.rlhf_trainer.vllm_client import VLLMClient vllm_client = None - if is_master(): + if is_last_rank(): logger.info('Start connecting to vLLM server') vllm_client = VLLMClient( base_urls=self.args.vllm_server_base_url, diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 8f6c3d67dd..0ca5ebc64d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -551,6 +551,7 @@ def _get_encoded_batch(rollout_batch, advantages): 'completion_mask': labels != -100, 'truncated_mask': truncated_mask, 'advantages': advantages, + 'num_samples': len(rollout_batch), }) return encoded_batch From 7ebfdda67d8187ea24c99f5fccca2c9a1c3f4806 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 13 Nov 2025 17:16:16 +0800 Subject: [PATCH 62/83] fix server mode --- swift/megatron/trainers/grpo_trainer.py | 2 +- swift/trainers/rlhf_trainer/vllm_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 0ca5ebc64d..0e49f6981b 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -734,7 +734,7 @@ def _server_rollout(self, all_outputs = [None] * len(all_requests) if not is_global_inputs: - all_outputs = broadcast_object_list(all_outputs, from_process=0) + all_outputs = broadcast_object_list(all_outputs, from_process=self.world_size - 1) start_idx = sum(all_requests_lengths[:self.process_index]) end_idx = start_idx + all_requests_lengths[self.process_index] outputs = all_outputs[start_idx:end_idx] diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 704a5eb63f..2de38550d6 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -213,7 +213,7 @@ def init_communicator(self, device: Union[int, str] = 0): pg = StatelessProcessGroup.create( host=self.hosts[i], port=self.group_ports[i], rank=rank, world_size=world_size) - comm = PyNcclCommunicator(pg, device=0) + comm = PyNcclCommunicator(pg, device=device) self.pynccl_comms.append(comm) atexit.register(self.close_communicator) From 43fc27d171611b32a5e95e9266ec370d48e73e2c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 13 Nov 2025 20:32:39 +0800 Subject: [PATCH 63/83] fix server pass prompt --- swift/megatron/trainers/grpo_trainer.py | 43 +++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 0e49f6981b..1416dd7184 100644 --- 
a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -628,6 +628,7 @@ def _generate_completions(self, batch): return batch def _rollout(self, batch) -> List[RolloutOutput]: + batch = self._set_inputs_system(batch) request_config = self._get_request_config() # TODO: server mode if self.vllm_mode == 'server': @@ -1106,7 +1107,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): elif self.loss_type == 'bnpo': loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': - loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + loss = (per_token_loss * completion_mask).sum() / (micro_batch_size * self.max_completion_length) else: raise ValueError(f'Unknown loss type: {self.loss_type}') @@ -1343,6 +1344,8 @@ def _prepare_metrics(self): 'advantages': deque(maxlen=args.generation_batch_size), } if is_wandb_available(): + # when log profiling, the step is different from the step in the training loop + # here patch wandb log to pop the step argument from wandb.sdk.wandb_run import Run origin_log = Run.log from functools import wraps @@ -1351,7 +1354,6 @@ def _prepare_metrics(self): def log(self, data: dict[str, Any], step: int | None = None, commit: bool | None = None): return origin_log(self, data, None, commit) - # Directly replace the class method, no need for MethodType Run.log = log def _apply_chat_template_to_messages_list(self, messages_list: DataType): @@ -1362,3 +1364,40 @@ def _apply_chat_template_to_messages_list(self, messages_list: DataType): res = self.template.encode(template_inputs) prompts_text.append(self.template.safe_decode(res['input_ids'])) return prompts_text + + def _set_inputs_system(self, batch: DataType) -> DataType: + """ + Ensures the system message is consistently set for all conversations in the batch. + + The template handles the user-defined system message. However, in server mode, + tokenization occurs on the rollout side. To prevent a mismatch where the system + message is set only during training but missing during rollout, this method + injects the default system message into each conversation if not already present. + + Args: + batch: A list of data items, each containing a 'messages' list. + + Returns: + The updated batch with the default system message inserted at the beginning + of each conversation that lacks one. 
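+
+        Example (illustrative, assuming ``vllm_mode == 'server'`` and a
+        non-empty ``template_meta.default_system``; the actual system text
+        comes from ``self.template.template_meta.default_system``):
+
+            >>> batch = [{'messages': [{'role': 'user', 'content': 'What is 1 + 1?'}]}]
+            >>> batch = self._set_inputs_system(batch)
+            >>> batch[0]['messages'][0]['role']
+            'system'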
+ """ + + if self.vllm_mode != 'server': + return batch + + # Return early if no default system message is defined + if not self.template.template_meta.default_system: + return batch + + # Return early if all conversations already start with a system message + if all(data['messages'][0]['role'] == 'system' for data in batch): + return batch + + # Insert the default system message at the beginning of each conversation + # that doesn't already have one + for data in batch: + messages = data['messages'] + if messages[0]['role'] != 'system': + messages.insert(0, {'role': 'system', 'content': self.template.template_meta.default_system}) + + return batch From 1696ea9cd33fe2b9479d55e4c397fb7032f78d0d Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 13 Nov 2025 23:48:07 +0800 Subject: [PATCH 64/83] dense script --- examples/megatron/grpo/dense_colocate.sh | 66 ++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 examples/megatron/grpo/dense_colocate.sh diff --git a/examples/megatron/grpo/dense_colocate.sh b/examples/megatron/grpo/dense_colocate.sh new file mode 100644 index 0000000000..a9ba1d1aea --- /dev/null +++ b/examples/megatron/grpo/dense_colocate.sh @@ -0,0 +1,66 @@ +# DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size) +# = 8 // (1 * 1 * 1) = 8 + +# NOTE: global_batch_size / micro_batch_size is completion-level +# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (128) +# generation_batch_size = global_batch_size * steps_per_generation (128 * 4 = 512) +# num_of_prompt_to_rollout = generation_batch_size / num_generations (512 / 8 = 64) +# num_of_prompt_to_train = generation_batch_size / num_generations (128 / 8 = 16) + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +MAX_PIXELS=602112 \ +MASTER_PORT=29600 \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 1 \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --dataset AI-ModelScope/clevr_cogen_a_train#10000 \ + --max_epochs 1 \ + --global_batch_size 128 \ + --micro_batch_size 4 \ + --steps_per_generation 4 \ + --num_generations 8 \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_r1v_acc format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.7 \ + --vllm_max_model_len 10240 \ + --max_length 8192 \ + --max_completion_length 2048 \ + --train_type full \ + --lr 1e-6 \ + --bf16 true \ + --beta 0.001 \ + --importance_sampling_level token \ + --epsilon 0.2 \ + --epsilon_high 0.2 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --log_interval 1 \ + --recompute_granularity selective \ + --max_epochs 1 \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --system examples/train/grpo/prompt.txt \ + --padding_free true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ + --train_iters 100 \ + --eval_interval 1000 \ + --save_interval 1000 From a629fbda85b0f9ce963a611286c068da58aaff5a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 00:07:48 +0800 Subject: [PATCH 65/83] check batch size params --- examples/megatron/grpo/dense_server.sh | 66 ++++++++++++++++++++++++ 
swift/megatron/argument/megatron_args.py | 41 ++++++++++----- 2 files changed, 93 insertions(+), 14 deletions(-) create mode 100644 examples/megatron/grpo/dense_server.sh diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh new file mode 100644 index 0000000000..16b0ce18de --- /dev/null +++ b/examples/megatron/grpo/dense_server.sh @@ -0,0 +1,66 @@ +# DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size) +# = 6 // (1 * 1 * 1) = 8 + +# NOTE: global_batch_size / micro_batch_size is completion-level +# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (128) +# generation_batch_size = global_batch_size * steps_per_generation (128 * 4 = 512) +# num_of_prompt_to_rollout = generation_batch_size / num_generations (512 / 8 = 64) +# num_of_prompt_to_train = generation_batch_size / num_generations (128 / 8 = 16) + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +MAX_PIXELS=602112 \ +MASTER_PORT=29600 \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 1 \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --dataset AI-ModelScope/clevr_cogen_a_train#10000 \ + --max_epochs 1 \ + --global_batch_size 128 \ + --micro_batch_size 4 \ + --steps_per_generation 4 \ + --num_generations 8 \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_r1v_acc format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.7 \ + --vllm_max_model_len 10240 \ + --max_length 8192 \ + --max_completion_length 2048 \ + --train_type full \ + --lr 1e-6 \ + --bf16 true \ + --beta 0.001 \ + --importance_sampling_level token \ + --epsilon 0.2 \ + --epsilon_high 0.2 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --log_interval 1 \ + --recompute_granularity selective \ + --max_epochs 1 \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --system examples/train/grpo/prompt.txt \ + --padding_free true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ + --train_iters 100 \ + --eval_interval 1000 \ + --save_interval 1000 diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index f6ee02c24d..a6e4b4e431 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -150,31 +150,44 @@ def _check_not_supported(): pass def _check_batch_params(): + # Set default values if both are None if self.generation_batch_size is None and self.steps_per_generation is None: self.steps_per_generation = 1 self.generation_batch_size = self.global_batch_size * self.steps_per_generation - elif self.generation_batch_size is not None and self.steps_per_generation is None: - # Just ensure the value is divisible by the global batch size + # Both configured - error + elif self.generation_batch_size is not None and self.steps_per_generation is not None: + raise ValueError("'generation_batch_size' and 'steps_per_generation' cannot be both configured") + # Only generation_batch_size configured + elif self.generation_batch_size is not None: if self.generation_batch_size % self.global_batch_size != 0: raise 
ValueError(f'generation_batch_size ({self.generation_batch_size}) ' - f'must be divisible by the global batch size ({self.global_batch_size}).') + f'must be divisible by global_batch_size ({self.global_batch_size})') self.steps_per_generation = self.generation_batch_size // self.global_batch_size - elif self.generation_batch_size is None and self.steps_per_generation is not None: - self.generation_batch_size = self.global_batch_size * self.steps_per_generation + # Only steps_per_generation configured else: - raise ValueError( - "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + world_size = torch.distributed.get_world_size() - # total_model_size = TP × PP × CP, - # data_parallel_size = world_size // total_model_size dp_size = world_size // ( self.pipeline_model_parallel_size * self.tensor_model_parallel_size * self.context_parallel_size) num_rollout_prompt = self.generation_batch_size // self.num_generations - assert num_rollout_prompt % dp_size == 0, ( - f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size ' - f'({self.generation_batch_size}) // num_generations ({self.num_generations}) ' - f'must be divisible by the dp size ({dp_size})' - f'please adjust generation_batch_size/steps_per_generation/num_generations to make it divisible') + if num_rollout_prompt % dp_size != 0: + raise ValueError(f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size ' + f'({self.generation_batch_size}) // num_generations ({self.num_generations}) ' + f'must be divisible by dp_size ({dp_size}). ' + f'Please adjust generation_batch_size/steps_per_generation/num_generations.') + + per_device_num_rollout_prompt = num_rollout_prompt // dp_size + + if per_device_num_rollout_prompt % self.micro_batch_size != 0: + raise ValueError(f'Per-device rollout prompt count ({per_device_num_rollout_prompt}) = ' + f'(generation_batch_size ({self.generation_batch_size}) // ' + f'num_generations ({self.num_generations})) // dp_size ({dp_size}) ' + f'must be divisible by micro_batch_size ({self.micro_batch_size}). 
' + f'Please adjust arguments to satisfy: ' + f'(generation_batch_size // num_generations) // dp_size % ' + f'micro_batch_size == 0') + self.per_device_generation_batch_size = self.generation_batch_size // world_size _check_not_supported() From 852d0f075e0c8e644acfeb0e404e4a7f1e3026f3 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 00:10:24 +0800 Subject: [PATCH 66/83] dense server script --- examples/megatron/grpo/dense_colocate.sh | 2 +- examples/megatron/grpo/dense_server.sh | 25 +++++++++++++++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/megatron/grpo/dense_colocate.sh b/examples/megatron/grpo/dense_colocate.sh index a9ba1d1aea..a7b79f2bd8 100644 --- a/examples/megatron/grpo/dense_colocate.sh +++ b/examples/megatron/grpo/dense_colocate.sh @@ -1,7 +1,7 @@ # DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size) # = 8 // (1 * 1 * 1) = 8 -# NOTE: global_batch_size / micro_batch_size is completion-level +# NOTE: global_batch_size and micro_batch_size are completion-level # global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (128) # generation_batch_size = global_batch_size * steps_per_generation (128 * 4 = 512) # num_of_prompt_to_rollout = generation_batch_size / num_generations (512 / 8 = 64) diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh index 16b0ce18de..9505c9d37e 100644 --- a/examples/megatron/grpo/dense_server.sh +++ b/examples/megatron/grpo/dense_server.sh @@ -1,14 +1,21 @@ +# MAX_PIXELS=602112 \ +# CUDA_VISIBLE_DEVICES=6,7 \ +# swift rollout \ +# --model Qwen/Qwen2.5-VL-3B-Instruct \ +# --vllm_data_parallel_size 2 \ +# --vllm_max_model_len 10240 + # DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size) -# = 6 // (1 * 1 * 1) = 8 +# = 6 // (1 * 1 * 1) = 6 -# NOTE: global_batch_size / micro_batch_size is completion-level -# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (128) -# generation_batch_size = global_batch_size * steps_per_generation (128 * 4 = 512) -# num_of_prompt_to_rollout = generation_batch_size / num_generations (512 / 8 = 64) -# num_of_prompt_to_train = generation_batch_size / num_generations (128 / 8 = 16) +# NOTE: global_batch_size and micro_batch_size are completion-level +# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (96) +# generation_batch_size = global_batch_size * steps_per_generation (96 * 4 = 384) +# num_of_prompt_to_rollout = generation_batch_size / num_generations (384 / 8 = 48) +# num_of_prompt_to_train = generation_batch_size / num_generations (96 / 8 = 12) -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ -NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \ +NPROC_PER_NODE=6 \ MAX_PIXELS=602112 \ MASTER_PORT=29600 \ megatron rlhf \ @@ -21,7 +28,7 @@ megatron rlhf \ --pipeline_model_parallel_size 1 \ --dataset AI-ModelScope/clevr_cogen_a_train#10000 \ --max_epochs 1 \ - --global_batch_size 128 \ + --global_batch_size 96 \ --micro_batch_size 4 \ --steps_per_generation 4 \ --num_generations 8 \ From 2f98eba874db2cc9d092e99f9b0e5f7612b530bb Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 11:11:55 +0800 Subject: [PATCH 67/83] moe script --- examples/megatron/grpo/moe_colocate_full.sh | 55 +++++++++++++++++++++ examples/megatron/grpo/moe_colocate_lora.sh | 53 ++++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 
examples/megatron/grpo/moe_colocate_full.sh create mode 100644 examples/megatron/grpo/moe_colocate_lora.sh diff --git a/examples/megatron/grpo/moe_colocate_full.sh b/examples/megatron/grpo/moe_colocate_full.sh new file mode 100644 index 0000000000..7b66688fd9 --- /dev/null +++ b/examples/megatron/grpo/moe_colocate_full.sh @@ -0,0 +1,55 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 1 \ + --tensor_model_parallel_size 4 \ + --expert_model_parallel_size 4 \ + --pipeline_model_parallel_size 2 \ + --dataset open-r1/DAPO-Math-17k-Processed \ + --max_epochs 1 \ + --global_batch_size 8 \ + --micro_batch_size 1 \ + --steps_per_generation 1 \ + --num_generations 8 \ + --reward_funcs accuracy format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 8 \ + --vllm_max_model_len 16384 \ + --max_length 8192 \ + --max_completion_length 8192 \ + --train_type full \ + --lr 1e-6 \ + --bf16 true \ + --beta 0.00 \ + --importance_sampling_level sequence \ + --epsilon 3e-4 \ + --epsilon_high 4e-4 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --optimizer_cpu_offload true \ + --use_precision_aware_optimizer \ + --log_interval 1 \ + --recompute_granularity selective \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --padding_free true \ + --sequence_parallel true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ diff --git a/examples/megatron/grpo/moe_colocate_lora.sh b/examples/megatron/grpo/moe_colocate_lora.sh new file mode 100644 index 0000000000..361a233e6c --- /dev/null +++ b/examples/megatron/grpo/moe_colocate_lora.sh @@ -0,0 +1,53 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 2 \ + --tensor_model_parallel_size 2 \ + --expert_model_parallel_size 4 \ + --pipeline_model_parallel_size 2 \ + --dataset open-r1/DAPO-Math-17k-Processed \ + --max_epochs 1 \ + --global_batch_size 64 \ + --micro_batch_size 2 \ + --steps_per_generation 2 \ + --num_generations 8 \ + --reward_funcs accuracy format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.3 \ + --vllm_tensor_parallel_size 4 \ + --vllm_max_model_len 16384 \ + --max_length 8192 \ + --max_completion_length 8192 \ + --train_type lora \ + --lr 5e-5 \ + --bf16 true \ + --beta 0.00 \ + --importance_sampling_level sequence \ + --epsilon 3e-4 \ + --epsilon_high 4e-4 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --log_interval 1 \ + --recompute_granularity selective \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --padding_free true \ + --sequence_parallel true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name 
megatron_grpo \ From 1936a83dc2f63183aed1d8e2c73b16ef5bed17ce Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 15:07:59 +0800 Subject: [PATCH 68/83] docs --- docs/source/BestPractices/Embedding.md | 4 +- .../Megatron-SWIFT/Command-line-parameters.md | 120 ++++++++++-------- docs/source/Megatron-SWIFT/GRPO.md | 59 +++++++++ docs/source/index.rst | 1 + docs/source_en/BestPractices/Embedding.md | 4 +- docs/source_en/Megatron-SWIFT/GRPO.md | 59 +++++++++ docs/source_en/index.rst | 1 + 7 files changed, 192 insertions(+), 56 deletions(-) create mode 100644 docs/source/Megatron-SWIFT/GRPO.md create mode 100644 docs/source_en/Megatron-SWIFT/GRPO.md diff --git a/docs/source/BestPractices/Embedding.md b/docs/source/BestPractices/Embedding.md index 5a91180ce1..2fef01666e 100644 --- a/docs/source/BestPractices/Embedding.md +++ b/docs/source/BestPractices/Embedding.md @@ -106,8 +106,8 @@ infonce loss的评测会有下面几个指标: SWIFT提供了两个脚手架训练脚本: -- [gte模型](https://github.com/tastelikefeet/swift/blob/main/examples/train/embedding/train_gte.sh) -- [gme模型](https://github.com/tastelikefeet/swift/blob/main/examples/train/embedding/train_gme.sh) +- [gte模型](https://github.com/modelscope/swift/blob/main/examples/train/embedding/train_gte.sh) +- [gme模型](https://github.com/modelscope/swift/blob/main/examples/train/embedding/train_gme.sh) ## 推理 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index a3fdde4529..b2ad29e204 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -244,58 +244,6 @@ lora训练: - lora_bias: 默认为`'none'`,可以选择的值: 'none'、'all'。如果你要将bias全都设置为可训练,你可以设置为`'all'`。 - use_rslora: 默认为`False`,是否使用`RS-LoRA`。 - -**DPO参数**: -- ref_load: ref_model的加载路径。采用DPO/GRPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 -- ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。 -- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 -- 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。 - - **注意**:在"ms-swift<3.8",其默认值为`1.`。在"ms-swift>=3.8"该默认值修改为`None`。 -- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 -- label_smoothing: 默认为0.。 -- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 -- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。 - -**KTO参数**: -- ref_load: 含义同DPO。 -- ref_adapter_load: 含义同DPO。 -- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。 -- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。 -- desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 -- undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 - -**GRPO参数** -- ref_load: 含义同DPO。 -- ref_adapter_load: 含义同DPO。 -- beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 -- epsilon: clip 系数,默认为0.2。 -- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 -- 
overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 、 `sequence` 和 `sequence_token`,默认为`token`。具体参考[GSPO文档](../Instruction/GRPO/AdvancedResearch/GSPO.md) -- batch size 相关参数(注意以下均为 completion-level) - - micro_batch_size: 每个device的批次大小,默认为1。 - - global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。对应每次更新权重的训练数据大小(mini_batch_size) - - generation_batch_size: 采样批量大小,需要是global_batch_size的倍数,默认等于global_batch_size - - steps_per_generation:每轮生成的优化步数,即采样批量大小相对global_batch_size的倍数,默认为1。 - - num_generations:每个prompt采样的数量,论文中的G值。采样批量大小需被num_generations 整除。默认为 8。 -- reward_funcs: GRPO算法奖励函数,可选项为`accuracy`、`format`、`cosine`、`repetition`和`soft_overlong`,见swift/plugin/orm.py。你也可以在plugin中自定义自己的奖励函数。默认为`[]`。 -- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。如果为 None,则所有奖励的权重都相等,为`1.0`。 -- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)。 -- vllm_mode 参数 - - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 - - vllm_max_model_len: vllm透传参数,默认为None。 - - vllm_enforce_eager: vllm透传参数,默认为False。 - - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 - - vllm_enable_prefix_caching: vllm透传参数,默认为True。 - - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放 - - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 - - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 - -内置奖励函数参数参考[文档](../Instruction/命令行参数.md#奖励函数参数) - -**RM参数**: -- center_rewards_coefficient: 用于激励奖励模型输出均值为零的奖励的系数,具体查看这篇[论文](https://huggingface.co/papers/2312.09244)。推荐值:0.01。 - **Mcore-Bridge参数** - 🔥load_safetensors: 默认为False,是否直接从safetensors加载权重。 - 🔥save_safetensors: 默认为False,是否直接保存成safetensors权重。注意,若该参数设置为True,则不会存储优化器权重、随机数状态等断点续训内容。 @@ -344,6 +292,74 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - calculate_per_token_loss: 覆盖Megatron参数,默认为False。 +### DPO参数 +- ref_load: ref_model的加载路径。采用DPO/GRPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 +- ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。 +- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 +- 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。 + - **注意**:在"ms-swift<3.8",其默认值为`1.`。在"ms-swift>=3.8"该默认值修改为`None`。 +- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 +- label_smoothing: 默认为0.。 +- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 +- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。 + +### KTO参数 +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。 +- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。 +- desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 +- undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 + +### RM参数 +- center_rewards_coefficient: 用于激励奖励模型输出均值为零的奖励的系数,具体查看这篇[论文](https://huggingface.co/papers/2312.09244)。推荐值:0.01。 + +### GRPO参数 +- 
ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 +- micro_batch_size: 每个device的批次大小,默认为1。 +- global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。 +- steps_per_generation:每轮生成的优化步数,即采样批量大小相对global_batch_size的倍数,默认为1。 +- generation_batch_size: 采样批量大小,需要是global_batch_size的倍数,默认等于global_batch_size*steps_per_generation。 +- num_generations: 每个prompt采样的数量,论文中的G值,默认为8。 +- reward_funcs: GRPO算法奖励函数,可选项为`accuracy`、`format`、`cosine`、`repetition`和`soft_overlong`,见swift/plugin/orm.py。你也可以在plugin中自定义自己的奖励函数。默认为`[]`。 +- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。默认为 None,即所有奖励的权重都相等,为`1.0`。 + - 提示:如果GRPO训练中包含`--reward_model`,则其加在奖励函数的最后位置。 +- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)。 +- log_completions: 是否记录训练中的模型生成内容,默认为False。 +- vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, +- vllm_mode server 参数 + - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 + - vllm_server_host:vLLM server host地址,默认为None。 + - vllm_server_port vLLM server 服务端口,默认为8000。 + - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 + - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 + - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 +- vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) + - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 + - vllm_max_model_len: vllm透传参数,默认为None。 + - vllm_enforce_eager: vllm透传参数,默认为False。 + - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 + - vllm_enable_prefix_caching: vllm透传参数,默认为True。 + - vllm_tensor_parallel_size: tp并行数,默认为`1`。 + - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](../Instruction/GRPO/GetStarted/GRPO.md#权重同步加速)。 + - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1, 2], 默认为0,不释放。 + - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 + - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 +- num_iterations: 每条数据的更新次数,[GRPO论文](https://arxiv.org/abs/2402.03300)中的 $\mu$ 值,默认为1。 +- epsilon: clip 系数,默认为0.2。 +- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 +- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 +- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 +- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 +- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +- scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。 + +内置奖励函数参数参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数) + ## 导出参数 这里介绍`megatron export`的参数(需"ms-swift>=3.10"),若要使用`swift export`导出命令,请参考[ms-swift命令行参数文档](../Instruction/Command-line-parameters.md#导出参数)。`megatron export`相比`swift export`,支持分布式和多机导出。Megatron导出参数继承自Megatron参数和基本参数。 - 🔥to_mcore: HF格式权重转成Megatron格式。默认为False。 diff --git a/docs/source/Megatron-SWIFT/GRPO.md b/docs/source/Megatron-SWIFT/GRPO.md new 
file mode 100644 index 0000000000..79c0080be8 --- /dev/null +++ b/docs/source/Megatron-SWIFT/GRPO.md @@ -0,0 +1,59 @@ +# GRPO + +**版本依赖**:ms-swift >= 3.11 + +如果你是首次使用 GRPO,请先参考 [GRPO文档](../Instruction/GRPO/GetStarted/GRPO.md)。 + +Megatron GRPO 当前已支持以下功能: + +- **训练模式**:全参数训练与 LoRA 微调 +- **并行策略**:支持上下文并行(CP)、流水线并行(PP)、张量并行(TP)和专家并行(EP) +- **推理加速**:支持 vLLM 的 colocate 模式和 server 模式 +- **模型支持**:兼容 Megatron Swift 中的 LLM 及 MLLM(多模态大模型) +- **算法支持**:涵盖 swift GRPO 的大部分功能 + +以下参数或功能将在后续版本中逐步支持: + +- **Entropy 相关配置**:如 `top_entropy_quantile`、`log_entropy` +- **Reward Model / Reward Model Plugin** +- **多轮 Rollout 调度机制**(`multi_turn_scheduler`):实现多轮对话策略优化 +- **优势估计器**(`advantage_estimator`):支持更复杂的策略梯度估计方法 +- **KL 散度计入奖励**(`kl_in_reward`) +- **虚拟流水线并行**(VPP) +- **参考模型同步更新**(`sync_ref_model`) +- **日志同步 SwanLab** + +⚠️ 注意:以下参数在 Megatron GRPO 中不生效: + +- **`use_vllm`**:Megatron GRPO 暂不支持使用 PTEngine 进行 Rollout 推理。 +- **`move_model_batches`**:该参数专用于 DeepSpeed ZeRO-3 优化,在 Megatron 架构下无效。 + +与 ms-swift GRPO 相同,Megatron GRPO batch size 相关的参数均以 **completion-level** 为单位,即表示模型生成的 completion 数量,而非 prompt 数量。 + +#### 参数对比 + +下表对比了 ms-swift 和 Megatron-SWIFT 中批量相关参数的对应关系: + +| ms-swift 参数 | Megatron-SWIFT 参数 | 说明 | +|---------------|---------------------|------| +| `per_device_train_batch_size` | `micro_batch_size` | 每张 GPU 的训练批次大小(completion-level) | +| `gradient_accumulation_steps` | - | 梯度累积步数,在 Megatron-SWIFT 中已包含在 `global_batch_size` 的计算中 | +| - | `global_batch_size` | 全局批次大小(completion-level)
**Megatron-SWIFT**: `micro_batch_size × dp_size × gradient_accumulation_steps`
**ms-swift**: `per_device_train_batch_size × world_size × gradient_accumulation_steps` | +| `num_generations` | `num_generations` | 每个 prompt 生成的 completion 数量 | +| `steps_per_generation` | `steps_per_generation` | Rollout 批次大小相对于训练批次大小的倍数
**注意**:在 ms-swift 中需为 `gradient_accumulation_steps` 的整数倍 | +| `generation_batch_size` | `generation_batch_size` | Rollout 阶段的批次大小(completion-level),需为 `global_batch_size` 的整数倍 | + +以下公式用于计算 Megatron GRPO 中的批量: + +- **数据并行大小**:`dp_size = world_size / (TP × PP × CP)` +- **全局批次大小**:`global_batch_size = micro_batch_size × dp_size × gradient_accumulation_steps` +- **生成批次大小**:`generation_batch_size = global_batch_size × steps_per_generation` +- **Rollout Prompt 数量**:`num_rollout_prompts = generation_batch_size / num_generations` +- **训练 Prompt 数量**:`num_train_prompts = global_batch_size / num_generations` +- **每个 DP group 的训练 Prompt 数量**:`num_prompts_per_dp_group = global_batch_size / num_generations / dp_size` + +注意:在 Megatron GRPO 中,每个 DP group 的训练 Prompt 数量须满足 `num_prompts_per_dp_group` 是 `micro_batch_size`的整数倍,以确保训练批次能够正确分配。 + +更多参数请参考[命令行文档](./Command-line-parameters.md#grpo参数) + +训练脚本请参考[Megatron GRPO 脚本](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo) diff --git a/docs/source/index.rst b/docs/source/index.rst index c5a5fc08c8..f70a8a05c9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,6 +42,7 @@ Swift DOCUMENTATION Megatron-SWIFT/LoRA-Training.md Megatron-SWIFT/Multimodal-Model.md Megatron-SWIFT/Mcore-Bridge.md + Megatron-SWIFT/GRPO.md .. toctree:: diff --git a/docs/source_en/BestPractices/Embedding.md b/docs/source_en/BestPractices/Embedding.md index 5794e95d23..3f97f07814 100644 --- a/docs/source_en/BestPractices/Embedding.md +++ b/docs/source_en/BestPractices/Embedding.md @@ -108,8 +108,8 @@ The evaluation of InfoNCE loss includes the following metrics: SWIFT provides two scaffold training scripts: -- [GTE Model](https://github.com/tastelikefeet/swift/blob/main/examples/train/embedding/train_gte.sh) -- [GME Model](https://github.com/tastelikefeet/swift/blob/main/examples/train/embedding/train_gme.sh) +- [GTE Model](https://github.com/modelscope/swift/blob/main/examples/train/embedding/train_gte.sh) +- [GME Model](https://github.com/modelscope/swift/blob/main/examples/train/embedding/train_gme.sh) ## Inference diff --git a/docs/source_en/Megatron-SWIFT/GRPO.md b/docs/source_en/Megatron-SWIFT/GRPO.md new file mode 100644 index 0000000000..d72b872f17 --- /dev/null +++ b/docs/source_en/Megatron-SWIFT/GRPO.md @@ -0,0 +1,59 @@ +# Megatron GRPO + +**Version Requirement**: ms-swift >= 3.11 + +If you are new to GRPO, please refer to the [GRPO documentation](../Instruction/GRPO/GetStarted/GRPO.md) first. 
+ +Megatron GRPO currently supports the following features: + +- **Training Modes**: Full parameter training and LoRA fine-tuning +- **Parallelism Strategies**: Context Parallelism (CP), Pipeline Parallelism (PP), Tensor Parallelism (TP), and Expert Parallelism (EP) +- **Inference Acceleration**: vLLM colocate mode and server mode +- **Model Support**: Compatible with LLMs and MLLMs (multimodal large models) in Megatron Swift +- **Algorithm Support**: Covers most features of Swift GRPO + +The following parameters or features will be gradually supported in future versions: + +- **Entropy-related Configuration**: e.g., `top_entropy_quantile`, `log_entropy` +- **Reward Model / Reward Model Plugin** +- **Multi-turn Rollout Scheduling** (`multi_turn_scheduler`): Multi-turn conversation policy optimization +- **Advantage Estimator** (`advantage_estimator`): Support for more complex policy gradient estimation methods +- **KL Divergence in Reward** (`kl_in_reward`) +- **Virtual Pipeline Parallelism** (VPP) +- **Reference Model Synchronization** (`sync_ref_model`) +- **SwanLab Logging Integration** + +⚠️ **Note**: The following parameters are not effective in Megatron GRPO: + +- **`use_vllm`**: Megatron GRPO does not support using PTEngine for Rollout inference. +- **`move_model_batches`**: This parameter is specific to DeepSpeed ZeRO-3 optimization and is invalid in the Megatron architecture. + +Similar to ms-swift GRPO, all batch size-related parameters in Megatron GRPO are at the **completion-level**, meaning they represent the number of completions generated by the model, not the number of prompts. + +#### Parameter Comparison + +The following table compares the batch-related parameters between ms-swift and Megatron-SWIFT: + +| ms-swift Parameter | Megatron-SWIFT Parameter | Description | +|-------------------|--------------------------|-------------| +| `per_device_train_batch_size` | `micro_batch_size` | Training batch size per GPU (completion-level) | +| `gradient_accumulation_steps` | - | Gradient accumulation steps, already included in `global_batch_size` calculation in Megatron-SWIFT | +| - | `global_batch_size` | Global batch size (completion-level)
**Megatron-SWIFT**: `micro_batch_size × dp_size × gradient_accumulation_steps`
**ms-swift**: `per_device_train_batch_size × world_size × gradient_accumulation_steps` | +| `num_generations` | `num_generations` | Number of completions generated per prompt | +| `steps_per_generation` | `steps_per_generation` | Ratio of Rollout batch size to training batch size
**Note**: In ms-swift, must be an integer multiple of `gradient_accumulation_steps` | +| `generation_batch_size` | `generation_batch_size` | Batch size during Rollout phase (completion-level), must be an integer multiple of `global_batch_size` | + +The following formulas are used to calculate batch sizes in Megatron GRPO: + +- **Data Parallel Size**: `dp_size = world_size / (TP × PP × CP)` +- **Global Batch Size**: `global_batch_size = micro_batch_size × dp_size × gradient_accumulation_steps` +- **Generation Batch Size**: `generation_batch_size = global_batch_size × steps_per_generation` +- **Rollout Prompt Count**: `num_rollout_prompts = generation_batch_size / num_generations` +- **Training Prompt Count**: `num_train_prompts = global_batch_size / num_generations` +- **Training Prompt Count per DP Group**: `num_prompts_per_dp_group = global_batch_size / num_generations / dp_size` + +**Note**: In Megatron GRPO, the training prompt count per DP group must satisfy that `num_prompts_per_dp_group` is an integer multiple of `micro_batch_size` to ensure proper batch allocation during training. + +For more parameters, please refer to the [Command-line Parameters documentation](./Command-line-parameters.md#grpo-parameters). + +For training scripts, please refer to [Megatron GRPO Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo). diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index c5a5fc08c8..f70a8a05c9 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -42,6 +42,7 @@ Swift DOCUMENTATION Megatron-SWIFT/LoRA-Training.md Megatron-SWIFT/Multimodal-Model.md Megatron-SWIFT/Mcore-Bridge.md + Megatron-SWIFT/GRPO.md .. toctree:: From bdcaa51fd74c07fb3bfb3acd0e6a29253f1657e4 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 15:22:13 +0800 Subject: [PATCH 69/83] update doc --- .../Instruction/Command-line-parameters.md | 2 +- docs/source/Instruction/Use-tuners.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 127 ++++++++++-------- 3 files changed, 74 insertions(+), 57 deletions(-) diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index ef27b8a304..2eda79118f 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -616,7 +616,7 @@ reward模型参数将在PPO、GRPO中使用。 - log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) ##### 奖励函数参数 -内置的奖励函数参考[文档](./GRPO/DeveloperGuide/奖励函数.md) +内置的奖励函数参考[文档](./GRPO/DeveloperGuide/reward_function.md) cosine 奖励参数 - cosine_min_len_value_wrong:cosine 奖励函数参数,生成错误答案时,最小长度对应的奖励值。默认值为-0.5。 - cosine_max_len_value_wrong:生成错误答案时,最大长度对应的奖励值。默认值为0.0。 diff --git a/docs/source/Instruction/Use-tuners.md b/docs/source/Instruction/Use-tuners.md index c84ca6fe0c..7461877fc8 100644 --- a/docs/source/Instruction/Use-tuners.md +++ b/docs/source/Instruction/Use-tuners.md @@ -15,7 +15,7 @@ tuner是指附加在模型上的额外结构部分,用于减少训练参数量 - Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751) - Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119) - Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503) -- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | 
[Usage](ResTuning.md) > +- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) > - [PEFT](https://github.com/huggingface/peft)提供的tuners, 如AdaLoRA、DoRA、Fourierft等 ## 接口列表 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 31d2f33752..feea8a4755 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -262,61 +262,6 @@ LoRA Training: - lora_bias: Default is `'none'`. Available options: `'none'`, `'all'`. If you want all biases to be set as trainable, set this to `'all'`. - use_rslora: Default is `False`. Whether to use `RS-LoRA`. -**DPO Parameters** -- ref_load: The loading path for the reference model. This must be provided when using DPO/GRPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. -- ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`. -- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. -- 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default. - - **Note**: In "ms-swift<3.8", the default value was `1.`. Starting from "ms-swift>=3.8", the default has been changed to `None`. -- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. -- label_smoothing: Default is 0. -- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. -- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values. - -**KTO Parameters**: -- ref_load: same meaning as in DPO. -- ref_adapter_load: same meaning as in DPO. -- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`. -- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type. -- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. -- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. 
Default is `1.`. - -**RM Parameters**: -- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01. - -**GRPO Parameters** -- ref_load: Same meaning as in DPO. -- ref_adapter_load: Same meaning as in DPO. -- beta: KL regularization coefficient, default is 0.04. When set to 0, the reference model is not loaded. -- epsilon: Clip coefficient, default is 0.2. -- epsilon_high: Upper clip coefficient, default is None. When set, forms a clipping range [epsilon, epsilon_high] together with epsilon. -- overlong_filter: Skips samples that are truncated due to excessive length and excludes them from loss computation. Default is False. -- importance_sampling_level: Controls the level at which importance sampling ratios are computed. Options are `token`, `sequence`, and `sequence_token`. Default is `token`. See [GSPO Documentation](../Instruction/GRPO/AdvancedResearch/GSPO.md) for details. -- Batch Size Related Parameters (Note: all are completion-level) - - micro_batch_size: Batch size per device, default is 1. - - global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallelism size * gradient accumulation steps`. Default is 16. Corresponds to the mini_batch_size (number of training samples per weight update). - - generation_batch_size: Sampling batch size, must be a multiple of global_batch_size. Default equals global_batch_size. - - steps_per_generation: Number of optimization steps per generation round, i.e., the ratio of generation_batch_size to global_batch_size. Default is 1. - - num_generations: Number of samples generated per prompt (the "G" value in the paper). generation_batch_size must be divisible by num_generations. Default is 8. -- reward_funcs: Reward functions used in GRPO algorithm. Options include `accuracy`, `format`, `cosine`, `repetition`, and `soft_overlong`, defined in swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`. -- reward_weights: Weights assigned to each reward function. Must match the total number of reward functions and reward models. If None, all rewards are equally weighted with `1.0`. -- loss_type: Type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo']. Default is 'grpo'. See this [PR](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348) for details. - -- vLLM Parameters - - vllm_gpu_memory_utilization: Pass-through parameter to vLLM, default is 0.9. - - vllm_max_model_len: Pass-through parameter to vLLM, default is None. - - vllm_enforce_eager: Pass-through parameter to vLLM, default is False. - - vllm_limit_mm_per_prompt: Pass-through parameter to vLLM, default is None. - - vllm_enable_prefix_caching: Pass-through parameter to vLLM, default is True. - - sleep_level: Release vLLM GPU memory during training. Options are [0, 1], default is 0 (no release). - - offload_optimizer: Whether to offload optimizer states during vLLM inference. Default is False. - - offload_model: Whether to offload model weights during vLLM inference. Default is False. - -For built-in reward function parameters, refer to the [documentation](../Instruction/GRPO/DeveloperGuide/reward_function.md). - -**RM Parameters**: -- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. 
See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01. - **Mcore-Bridge Parameters** - 🔥load_safetensors: Defaults to False. Whether to load weights directly from safetensors. @@ -375,6 +320,78 @@ In addition to inheriting the training parameters, the following parameters are - calculate_per_token_loss: Overrides the Megatron parameter. Default is False. +### DPO Parameters + +- ref_load: The loading path for the reference model. This must be provided when using DPO/GRPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. +- ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`. +- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. +- 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default. + - **Note**: In "ms-swift<3.8", the default value was `1.`. Starting from "ms-swift>=3.8", the default has been changed to `None`. +- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. +- label_smoothing: Default is 0. +- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. +- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values. + +### KTO Parameters + +- ref_load: same meaning as in DPO. +- ref_adapter_load: same meaning as in DPO. +- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`. +- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type. +- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. +- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. + +### RM Parameters + +- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01. + +### GRPO Parameters + +- ref_load: Same meaning as in DPO. +- ref_adapter_load: Same meaning as in DPO. +- beta: KL regularization coefficient, default is 0.04. 
When set to 0, the ref model is not loaded.
+- micro_batch_size: Batch size per device, default is 1.
+- global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallel size * gradient accumulation steps`. Default is 16.
+- steps_per_generation: Number of optimization steps per generation round, i.e., the ratio of sampling batch size to global_batch_size. Default is 1.
+- generation_batch_size: Sampling batch size, must be a multiple of global_batch_size. Default equals global_batch_size * steps_per_generation.
+- num_generations: Number of samples per prompt, the G value in the paper, default is 8.
+- reward_funcs: GRPO algorithm reward functions. Options include `accuracy`, `format`, `cosine`, `repetition`, and `soft_overlong`. See swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`.
+- reward_weights: Weights for each reward function. Must match the total number of reward functions and reward models. Default is None, meaning all rewards have equal weights of `1.0`.
+  - Tip: If GRPO training includes `--reward_model`, it is added at the end of the reward functions.
+- loss_type: Loss normalization type. Options are `['grpo', 'bnpo', 'dr_grpo']`. Default is `'grpo'`. See this [PR](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348) for details.
+- log_completions: Whether to log model-generated content during training. Default is False.
+- vllm_mode: vLLM integration mode. Options are `server` and `colocate`. Server mode uses the vLLM server launched by `swift rollout` for sampling, while colocate mode deploys vLLM within the program.
+- vllm_mode server parameters:
+  - vllm_server_base_url: Base URL of the vLLM server (e.g., http://local_host:8000). Default is None. When set, host and port settings are ignored.
+  - vllm_server_host: vLLM server host address. Default is None.
+  - vllm_server_port: vLLM server port. Default is 8000.
+  - vllm_server_timeout: Timeout for connecting to the vLLM server. Default is 240s.
+  - vllm_server_pass_dataset: Pass additional dataset information to the vLLM server for multi-round training.
+  - async_generate: Asynchronous rollout to improve training speed. Note: When enabled, sampling uses the model from the previous round update, and multi-round scenarios are not supported. Default is `false`.
+  - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: Environment variable for controlling the bucket size during weight synchronization. Applicable to full-parameter training in Server Mode. Unit is MB, default value is 512 MB.
+- vllm_mode colocate parameters (for more parameter support, refer to [vLLM parameters](#vllm-parameters)):
+  - vllm_gpu_memory_utilization: vLLM passthrough parameter. Default is 0.9.
+  - vllm_max_model_len: vLLM passthrough parameter. Default is None.
+  - vllm_enforce_eager: vLLM passthrough parameter. Default is False.
+  - vllm_limit_mm_per_prompt: vLLM passthrough parameter. Default is None.
+  - vllm_enable_prefix_caching: vLLM passthrough parameter. Default is True.
+  - vllm_tensor_parallel_size: Tensor parallel size. Default is `1`.
+  - vllm_enable_lora: Support loading LoRA adapters in the vLLM Engine. Default is False. Used to accelerate weight synchronization in LoRA training. See [documentation](../Instruction/GRPO/GetStarted/GRPO.md#weight-synchronization-acceleration) for details.
+  - sleep_level: Release vLLM GPU memory during training. Options are `[0, 1, 2]`. Default is 0, meaning no release.
+ - offload_optimizer: Whether to offload optimizer parameters during vLLM inference. Default is False. + - offload_model: Whether to offload the model during vLLM inference. Default is False. +- num_iterations: Number of updates per data sample, the $\mu$ value in the [GRPO paper](https://arxiv.org/abs/2402.03300). Default is 1. +- epsilon: Clip coefficient. Default is 0.2. +- epsilon_high: Upper clip coefficient. Default is None. When set, together with epsilon, forms the clipping range `[epsilon, epsilon_high]`. +- dynamic_sample: Filter out data with zero reward standard deviation within groups and sample additional new data. Default is False. +- max_resample_times: Limit the number of resampling times under dynamic_sample setting. Default is 3. +- overlong_filter: Skip overlong truncated samples, which do not participate in loss calculation. Default is False. +- delta: Bilateral GRPO upper bound clipping value from the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). If set, it is recommended to be greater than 1 + epsilon. Default is None. +- importance_sampling_level: Controls importance sampling ratio calculation. Options are `token` and `sequence`. In `token` mode, the original log probability ratio for each token is preserved. In `sequence` mode, the log probability ratios of all valid tokens in the sequence are averaged. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level calculation to stabilize training. Default is `token`. +- scale_rewards: Specifies the reward scaling strategy. Options include `group` (scale by within-group standard deviation), `batch` (scale by batch-wide standard deviation), and `none` (no scaling). In ms-swift < 3.10, this parameter is boolean, where `true` corresponds to `group` and `false` corresponds to `none`. The default value is bound to `advantage_estimator`: `grpo` corresponds to `group`, `rloo` corresponds to `none`, and `reinforce_plus_plus` corresponds to `batch`. + +Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters). + ## Export Parameters This section introduces the parameters for `megatron export` (requires "ms-swift>=3.10"). To use the `swift export` command for exporting, please refer to the [ms-swift Command Line Parameters Documentation](../Instruction/Command-line-parameters.md#export-arguments). Compared to `swift export`, `megatron export` supports distributed and multi-node exporting. Megatron export parameters inherit from Megatron parameters and basic parameters. 
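The GRPO batch-size parameters documented above (micro_batch_size, global_batch_size, steps_per_generation, num_generations) interact through the formulas spelled out in the new GRPO.md. Below is a minimal sketch using the illustrative values from examples/megatron/grpo/dense_colocate.sh; it mirrors the divisibility checks of `_check_batch_params` rather than reproducing the Megatron implementation.

```python
# Illustrative values from dense_colocate.sh (8 GPUs, TP = PP = CP = 1).
world_size = 8
tp_size, pp_size, cp_size = 1, 1, 1
micro_batch_size = 4        # completion-level, per device
global_batch_size = 128     # completion-level
steps_per_generation = 4
num_generations = 8

dp_size = world_size // (tp_size * pp_size * cp_size)                  # 8
grad_accum_steps = global_batch_size // (micro_batch_size * dp_size)   # 4
generation_batch_size = global_batch_size * steps_per_generation       # 512

num_rollout_prompts = generation_batch_size // num_generations         # 64
num_train_prompts = global_batch_size // num_generations               # 16

# The divisibility constraints enforced by _check_batch_params:
assert num_rollout_prompts % dp_size == 0                               # 64 % 8 == 0
assert (num_rollout_prompts // dp_size) % micro_batch_size == 0         # 8 % 4 == 0
```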
From 027ca573322388a5f12028e24a365b5747effc64 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 15:39:06 +0800 Subject: [PATCH 70/83] update doc & args check --- .../Instruction/Command-line-parameters.md | 9 ---- docs/source/Megatron-SWIFT/GRPO.md | 2 + .../Instruction/Command-line-parameters.md | 9 ---- .../Instruction/GRPO/AdvancedResearch/GSPO.md | 2 +- docs/source_en/Instruction/Use-tuners.md | 2 +- docs/source_en/Megatron-SWIFT/GRPO.md | 2 + swift/megatron/argument/megatron_args.py | 43 ++++++++++++------- 7 files changed, 33 insertions(+), 36 deletions(-) diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 2eda79118f..40f9fda195 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -547,15 +547,6 @@ reward模型参数将在PPO、GRPO中使用。 #### GRPO参数 - beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 -- epsilon: clip 系数,默认为0.2。 -- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 -- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 -- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 -- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 -- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 -- top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) -- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 、 `sequence` 和 `sequence_token`,默认为`token`。具体参考[GSPO文档](./GRPO/AdvancedResearch/GSPO.md) - per_device_train_batch_size: 每个设备训练批量大小,在GRPO中,指 completion 的批次大小。 - per_device_eval_batch_size: 每个设备评估批量大小,在GRPO中,指 completion 的批次大小。 - generation_batch_size: 采样completion批量大小,需要是 num_processes * per_device_train_batch_size 的倍数,默认等于 per_device_batch_size * gradient_accumulation_steps * num_processes diff --git a/docs/source/Megatron-SWIFT/GRPO.md b/docs/source/Megatron-SWIFT/GRPO.md index 79c0080be8..a8aa4df0e4 100644 --- a/docs/source/Megatron-SWIFT/GRPO.md +++ b/docs/source/Megatron-SWIFT/GRPO.md @@ -21,6 +21,8 @@ Megatron GRPO 当前已支持以下功能: - **KL 散度计入奖励**(`kl_in_reward`) - **虚拟流水线并行**(VPP) - **参考模型同步更新**(`sync_ref_model`) +- **Async Generate** (`async_generate`) +- **num_iterations** - **日志同步 SwanLab** ⚠️ 注意:以下参数在 Megatron GRPO 中不生效: diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 5bfa2adade..6ce0658114 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -559,15 +559,6 @@ The meanings of the following parameters can be referenced [here](https://huggin #### GRPO Arguments - beta: KL regularization coefficient; default 0.04. Setting it to 0 disables the reference model. -- epsilon: epsilon value for clipping. Default is 0.2. -- epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon. -- delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). -- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. 
-- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False. -- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times. -- top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md). -- log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). -- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. - per_device_train_batch_size: The training batch size per device. In GRPO, this refers to the batch size of completions during training. - per_device_eval_batch_size: The evaluation batch size per device. In GRPO, this refers to the batch size of completions during evaluation. - generation_batch_size: Batch size to use for generation. It defaults to the effective training batch size: per_device_train_batch_size * num_processes * gradient_accumulation_steps` diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md index 2f8ec7ae54..5062188bf2 100644 --- a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -1,6 +1,6 @@ # Group Sequence Policy Optimization -**Version Requirement**: ms-swift>=3.7 +**Version Requirement**: ms-swift>=3.8 In [Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level. 
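The `token` versus `sequence` distinction discussed in this GSPO page comes down to where the log-probability ratio is aggregated. The sketch below illustrates only that switch, assuming per-token log-probabilities and a completion mask are already available; the function name and shapes are assumptions, the `sequence_token` variant is not covered, and the real computation lives in the GRPO trainers.

```python
import torch


def importance_log_ratio(per_token_logps: torch.Tensor,
                         old_per_token_logps: torch.Tensor,
                         completion_mask: torch.Tensor,
                         level: str = 'token') -> torch.Tensor:
    """Sketch of token-level (GRPO) vs sequence-level (GSPO) importance ratios.

    per_token_logps / old_per_token_logps: [batch, seq_len]
    completion_mask: 1.0 for valid completion tokens, 0.0 for padding
    """
    log_ratio = per_token_logps - old_per_token_logps
    if level == 'token':
        # one independently sampled ratio per token (original GRPO behavior)
        return log_ratio
    if level == 'sequence':
        # average the log-ratio over valid tokens to get one ratio per sequence,
        # then broadcast it back so downstream clipping stays shape-agnostic
        seq_log_ratio = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
        return seq_log_ratio.unsqueeze(-1).expand_as(log_ratio)
    raise ValueError(f'unsupported importance_sampling_level: {level}')
```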
diff --git a/docs/source_en/Instruction/Use-tuners.md b/docs/source_en/Instruction/Use-tuners.md index f960591893..d1b4f2cb1d 100644 --- a/docs/source_en/Instruction/Use-tuners.md +++ b/docs/source_en/Instruction/Use-tuners.md @@ -15,7 +15,7 @@ Tuners refer to additional structural components attached to a model, aimed at r - Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751) - Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119) - Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503) -- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) > +- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) > - Tuners provided by [PEFT](https://github.com/huggingface/peft), such as AdaLoRA, DoRA, Fourierft, etc. ## Interface List diff --git a/docs/source_en/Megatron-SWIFT/GRPO.md b/docs/source_en/Megatron-SWIFT/GRPO.md index d72b872f17..3fa9dfb58d 100644 --- a/docs/source_en/Megatron-SWIFT/GRPO.md +++ b/docs/source_en/Megatron-SWIFT/GRPO.md @@ -21,6 +21,8 @@ The following parameters or features will be gradually supported in future versi - **KL Divergence in Reward** (`kl_in_reward`) - **Virtual Pipeline Parallelism** (VPP) - **Reference Model Synchronization** (`sync_ref_model`) +- **Async Generate** (`async_generate`) +- **num_iterations** - **SwanLab Logging Integration** ⚠️ **Note**: The following parameters are not effective in Megatron GRPO: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 283159965b..5fa21eb5e6 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -52,7 +52,7 @@ class RLHFMegatronArgumentsMixin: top_k: int = 50 top_p: float = 0.9 repetition_penalty: float = 1. 
- use_vllm: bool = False + use_vllm: bool = True vllm_mode: Literal['server', 'colocate'] = 'colocate' vllm_enable_prefix_caching: bool = True @@ -63,6 +63,8 @@ class RLHFMegatronArgumentsMixin: vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' vllm_disable_cascade_attn: bool = False sleep_level: Literal[0, 1, 2] = 0 + offload_optimizer: bool = False + offload_model: bool = False vllm_server_base_url: Optional[List[str]] = None vllm_server_host: Optional[List[str]] = None @@ -107,8 +109,6 @@ class RLHFMegatronArgumentsMixin: async_generate: bool = False move_model_batches: Optional[int] = None - offload_optimizer: bool = False - offload_model: bool = False # multi turn multi_turn_scheduler: Optional[str] = None @@ -147,7 +147,20 @@ def __post_init__(self): def _init_grpo(self): def _check_not_supported(): - pass + if self.async_generate: + raise ValueError('async_generate is not supported for Megatron GRPO right now') + if self.sync_ref_model: + raise ValueError('sync_ref_model is not supported for Megatron GRPO right now') + if not self.dataset_shuffle: + raise ValueError('dataset_shuffle false is not supported for Megatron GRPO') + if self.multi_turn_scheduler: + raise ValueError('multi_turn_scheduler is not supported for Megatron GRPO right now') + if self.log_entropy: + raise ValueError('log_entropy is not supported for Megatron GRPO right now') + if self.top_entropy_quantile < 1: + raise ValueError('top_entropy_quantile < 1 is not supported for Megatron GRPO right now') + if self.num_iterations > 1: + raise ValueError('num_iterations > 1 is not supported for Megatron GRPO right now') def _check_batch_params(): # Set default values if both are None @@ -195,8 +208,6 @@ def _check_batch_params(): # default loss_type if no loss_type is provided assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \ f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}' - if self.async_generate or not self.use_vllm: - self.sleep_level = 0 self.remove_unused_columns = False logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') if self.truncation_strategy is None: @@ -218,16 +229,16 @@ def _check_batch_params(): if self.soft_max_length is None: self.soft_max_length = self.max_completion_length logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}') - if self.use_vllm: - # set vllm mode - if self.vllm_server_host is not None or self.vllm_server_base_url is not None: - if self.vllm_mode != 'server': - self.vllm_mode = 'server' - logger.warning('set vllm_mode to `server` since vllm server host/base_url is provided') - else: - if self.vllm_mode != 'colocate': - self.vllm_mode = 'colocate' - logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') + assert self.use_vllm, 'use_vllm must be True for Megatron GRPO' + # set vllm mode + if self.vllm_server_host is not None or self.vllm_server_base_url is not None: + if self.vllm_mode != 'server': + self.vllm_mode = 'server' + logger.warning('set vllm_mode to `server` since vllm server host/base_url is provided') + else: + if self.vllm_mode != 'colocate': + self.vllm_mode = 'colocate' + logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') @dataclass From 360da420176a0a38df54ae82de6483abfb4a9d7f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 15:43:07 +0800 Subject: [PATCH 71/83] clean up --- swift/megatron/argument/megatron_args.py | 11 ----------- 
swift/megatron/argument/train_args.py | 4 +--- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 5fa21eb5e6..afc911fb9e 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -19,7 +19,6 @@ @dataclass class RLHFMegatronArgumentsMixin: rlhf_type: Literal['dpo', 'kto', 'grpo', 'rm'] = None - perform_initialization: bool = True ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -230,15 +229,6 @@ def _check_batch_params(): self.soft_max_length = self.max_completion_length logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}') assert self.use_vllm, 'use_vllm must be True for Megatron GRPO' - # set vllm mode - if self.vllm_server_host is not None or self.vllm_server_base_url is not None: - if self.vllm_mode != 'server': - self.vllm_mode = 'server' - logger.warning('set vllm_mode to `server` since vllm server host/base_url is provided') - else: - if self.vllm_mode != 'colocate': - self.vllm_mode = 'colocate' - logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') @dataclass @@ -384,7 +374,6 @@ class MegatronArguments(ExtraMegatronArguments): dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic' manual_gc: bool = False manual_gc_interval: int = 0 - use_mbridge: bool = False # learning rate lr: Optional[float] = None diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 47a4772bf3..8a100a9380 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -5,9 +5,7 @@ import json from swift.llm.argument.base_args import to_abspath -from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_last_rank, is_master -from ..model import get_megatron_model_meta -from .megatron_args import MegatronArguments, RLHFMegatronArgumentsMixin +from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_last_rank from .megatron_base_args import MegatronBaseArguments logger = get_logger() From 6259a346821befdaecb4b9f16cfeb251ce55c34c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 15:48:40 +0800 Subject: [PATCH 72/83] clean up --- swift/megatron/argument/megatron_args.py | 3 +-- swift/megatron/trainers/base.py | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index afc911fb9e..80531833c5 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -10,8 +10,7 @@ from transformers.utils.versions import require_version from swift.llm import get_model_info_meta -from swift.llm.argument.base_args import to_abspath -from swift.utils import get_current_device, get_dist_setting, get_logger, is_master, json_parse_to_dict +from swift.utils import get_dist_setting, get_logger, json_parse_to_dict logger = get_logger() diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 43ce04c173..5ac2616611 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -4,7 +4,7 @@ import shutil import time from abc import ABC, abstractmethod -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager from datetime import datetime from typing import Callable, Dict, List, Literal, Optional @@ -29,10 +29,7 @@ from 
megatron.training.training import num_floating_point_operations from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from packaging import version -from peft.utils import ModulesToSaveWrapper -from torch.distributed.nn import all_reduce from tqdm.auto import tqdm -from transformers.utils import ContextManagers from swift.llm import Template, dynamic_gradient_checkpointing from swift.plugin import MeanMetric From 5fdde02df74edca35a2aab468527ab5a4bcde673 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 15:54:20 +0800 Subject: [PATCH 73/83] clean up --- swift/megatron/trainers/grpo_trainer.py | 29 ++++++++----------------- swift/megatron/trainers/utils.py | 19 +++------------- 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 1416dd7184..7dbb06015e 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -37,8 +37,8 @@ from ..utils import forward_step_helper from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, - load_megatron_optimizer, log_gpu_memory, offload_megatron_model_to_cpu, offload_megatron_optimizer, - patch_profiling_context, patch_profiling_decorator, profiling_context) + load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, + profiling_context, profiling_decorator) if is_wandb_available(): import wandb @@ -172,7 +172,6 @@ def _prepare_rollout_engine(self): self.engine = self.prepare_vllm() if self.args.sleep_level > 0: self.engine.engine.sleep(self.args.sleep_level) - log_gpu_memory('after sleep vLLM engine') set_expandable_segments(True) else: raise ValueError(f'Invalid vllm_mode: {self.vllm_mode}') @@ -207,7 +206,7 @@ def prepare_vllm(self): self._buffered_inputs = None return engine - @patch_profiling_decorator + @profiling_decorator def _move_model_to_vllm(self): # Handle LoRA: merge adapters before exporting weights is_lora_training = self.args.train_type == 'lora' @@ -244,7 +243,7 @@ def _export_and_load_weights(self): For server mode: Process weights in buckets to avoid memory spikes. """ # Export weights returns an iterator - with patch_profiling_context(self, 'export_weights'): + with profiling_context(self, 'export_weights'): weight_iterator = self.bridge.export_weights(self.unwrapped_models) if self.vllm_mode == 'colocate': @@ -565,13 +564,13 @@ def _get_encoded_batch(rollout_batch, advantages): micro_batch_data = self._maybe_replace_response_token(micro_batch_data) micro_batch_advantages = total_advantages[idx:idx + self.micro_batch_size] micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages) - with patch_profiling_context(self, 'compute_ref_old_logps'): + with profiling_context(self, 'compute_ref_old_logps'): micro_batch_data = self._maybe_compute_logps(micro_batch_data) mini_batch_data.append(micro_batch_data) return mini_batch_data - @patch_profiling_decorator + @profiling_decorator def _generate_completions(self, batch): """ Generate completions for a batch of rollout data using vLLM engine. 
@@ -593,9 +592,7 @@ def _generate_completions(self, batch): wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters # Load weights only (faster and reduces memory peak) kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} - log_gpu_memory(f'before wake up vLLM engine {kwargs}') self.engine.engine.wake_up(**kwargs) - log_gpu_memory(f'after wake up vLLM engine {kwargs}') # Step 2: Load model weights if self._step != self._last_loaded_step: @@ -608,9 +605,7 @@ def _generate_completions(self, batch): and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): aggressive_empty_cache() set_expandable_segments(False) - log_gpu_memory('before wake up vLLM engine kv_cache') self.engine.engine.wake_up(tags=['kv_cache']) - log_gpu_memory('after wake up vLLM engine kv_cache') # Step3: Rollout outputs: List[RolloutOutput] = self._rollout(batch) @@ -618,11 +613,9 @@ def _generate_completions(self, batch): # Step4: Sleep to release memory if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: self.engine.engine.reset_prefix_cache() - log_gpu_memory('before sleep vLLM engine') self.engine.engine.sleep(level=self.args.sleep_level) aggressive_empty_cache() set_expandable_segments(True) - log_gpu_memory('after sleep vLLM engine') batch = self.postprocess_rollout_data(batch, outputs) return batch @@ -764,7 +757,7 @@ def _colocate_rollout(self, batch, request_config: RequestConfig): return outputs - @patch_profiling_decorator + @profiling_decorator def _score_completions(self, inputs: DataType) -> torch.Tensor: """Score completions using all reward functions. @@ -1001,7 +994,7 @@ def build_pretraining_data_loader(*_args, **kwargs): finally: training.build_pretraining_data_loader = origin_build_pretraining_data_loader - @patch_profiling_decorator + @profiling_decorator def forward_step(self, data_iterator, model): # train_batch_size # return: output_tensor, loss_func @@ -1017,7 +1010,7 @@ def forward_step(self, data_iterator, model): output_tensor = model(**inputs) return output_tensor, partial(self.loss_func, data=data) - @patch_profiling_decorator + @profiling_decorator def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): advantages = data['advantages'] labels = data['labels'] @@ -1182,10 +1175,8 @@ def offload_context(self): offload_megatron_model_to_cpu(self.wrapped_models) if hasattr(self, 'ref_models') and self.ref_models: offload_megatron_model_to_cpu(self.ref_models) - log_gpu_memory('after offload model to cpu') if getattr(self, 'optimizer', None) and self.args.offload_optimizer: offload_megatron_optimizer(self.optimizer) - log_gpu_memory('after offload optimizer to cpu') try: yield @@ -1195,10 +1186,8 @@ def offload_context(self): load_megatron_model_to_gpu(self.wrapped_models) if hasattr(self, 'ref_models') and self.ref_models: load_megatron_model_to_gpu(self.ref_models) - log_gpu_memory('after load model to gpu') if getattr(self, 'optimizer', None) and self.args.offload_optimizer: load_megatron_optimizer(self.optimizer) - log_gpu_memory('after load optimizer to gpu') def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]: """Convert raw input data into RolloutInferRequest objects""" diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index c0e19ccb77..0649d2ea44 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -3,7 +3,7 @@ import gc import time from contextlib import contextmanager -from typing import Any, Dict, List, Optional +from typing 
import Any, Dict, Optional import torch from accelerate.utils import gather as hf_gather @@ -133,24 +133,11 @@ def profiling_context(trainer, name: str): wandb_writer.log(profiling_metrics) -@contextmanager -def patch_profiling_context(trainer, name: str): - start_time = time.perf_counter() - yield - end_time = time.perf_counter() - duration = end_time - start_time - - profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration} - wandb_writer = get_wandb_writer() - if wandb_writer and trainer.is_main_process: - wandb_writer.log(profiling_metrics) - - -def patch_profiling_decorator(func): +def profiling_decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - with patch_profiling_context(self, func.__name__): + with profiling_context(self, func.__name__): return func(self, *args, **kwargs) return wrapper From b5be0ce3aa93ca94a4bfbdda7b2922ea5c4d68a8 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 16:01:46 +0800 Subject: [PATCH 74/83] clean up --- swift/megatron/trainers/grpo_trainer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 7dbb06015e..c318936fdf 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -43,11 +43,6 @@ if is_wandb_available(): import wandb -try: - from trl.trainer.utils import entropy_from_logits -except ImportError: - from swift.trainers.rlhf_trainer.utils import entropy_from_logits - logger = get_logger() @@ -66,8 +61,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): self._prepare_scheduler() # TODO self._prepare_rollout_engine() - self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} - def train(self, train_dataset, val_dataset, data_collator): # Store dataset provider for lazy resample iterator initialization if self.dynamic_sample: @@ -1345,6 +1338,8 @@ def log(self, data: dict[str, Any], step: int | None = None, commit: bool | None Run.log = log + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} + def _apply_chat_template_to_messages_list(self, messages_list: DataType): prompts_text = [] for messages in messages_list: From d8c7c3b912fe37c3e803d79cbe3845f30789805c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 16:05:12 +0800 Subject: [PATCH 75/83] clean up --- swift/megatron/trainers/grpo_trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index c318936fdf..9792bf030f 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -19,8 +19,6 @@ from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator from megatron.training import get_args, get_wandb_writer, training -from torch._tensor import Tensor -from torch.distributed.nn import all_reduce from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps @@ -139,8 +137,7 @@ def _prepare_rollout_engine(self): self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.use_vllm = args.use_vllm - self.async_generate = args.async_generate - self.use_fast_infer = self.use_vllm + self.async_generate = args.async_generate # TODO 
self.vllm_use_async_engine = False self.enable_offload = False self.use_gym_env = False From 2aaa1e515c3a7bffb534aeff147c6114d30964d8 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 16:17:37 +0800 Subject: [PATCH 76/83] align scale_rewards --- swift/megatron/argument/megatron_args.py | 9 ++++++- swift/megatron/trainers/grpo_trainer.py | 33 +++++++++++++++++++----- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 80531833c5..eac250f9cf 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -90,12 +90,15 @@ class RLHFMegatronArgumentsMixin: overlong_filter: bool = False # Dr. GRPO, https://arxiv.org/abs/2503.20783 - scale_rewards: bool = True + scale_rewards: Literal['none', 'group', 'batch'] = 'group' wandb_log_unique_prompts: Optional[bool] = None log_completions: bool = False # ─────────────────────────── Not Supported Yet ─────────────────────────── + # RLOO / REINFORCE++ + advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus'] = 'grpo' + kl_in_reward: bool = False # reward model reward_model: Optional[List[str]] = None reward_model_plugin: Optional[List[str]] = None @@ -159,6 +162,10 @@ def _check_not_supported(): raise ValueError('top_entropy_quantile < 1 is not supported for Megatron GRPO right now') if self.num_iterations > 1: raise ValueError('num_iterations > 1 is not supported for Megatron GRPO right now') + if self.kl_in_reward: + raise ValueError('kl_in_reward is not supported for Megatron GRPO right now') + if self.advantage_estimator != 'grpo': + raise ValueError('advantage_estimator must be grpo for Megatron GRPO right now') def _check_batch_params(): # Set default values if both are None diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 9792bf030f..53269a8483 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -105,6 +105,11 @@ def _init_grpo_params(self): self.max_resample_times = args.max_resample_times self.overlong_filter = args.overlong_filter + # Dr. 
GRPO / RLOO / REINFORCE++ + self.scale_rewards = args.scale_rewards + self.advantage_estimator = args.advantage_estimator # TODO + self.kl_in_reward = args.kl_in_reward # TODO + # Entropy mask settings, TODO self.log_entropy = args.log_entropy self.compute_entropy = self.log_entropy or self.top_entropy_quantile < 1.0 @@ -795,9 +800,9 @@ def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor: def _compute_advantages(self, batch: DataType, rewards_per_func: torch.Tensor) -> torch.Tensor: """Compute advantages for RL training.""" - def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> torch.Tensor: + def normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> torch.Tensor: """Normalize advantages if configured; otherwise, return as-is.""" - if self.args.scale_rewards: + if self.scale_rewards != 'none': return advantages / (rewards_std + 1e-4) return advantages @@ -806,16 +811,28 @@ def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tens total_rewards_per_func = gather(rewards_per_func) rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) grouped_rewards = rewards.view(-1, self.num_generations) + + # Compute group statistics group_rewards_mean = grouped_rewards.mean(dim=1) - group_rewards_std = grouped_rewards.std(dim=1) # Broadcast stats back to the original shape group_rewards_mean = group_rewards_mean.repeat_interleave(self.num_generations) - group_rewards_std = group_rewards_std.repeat_interleave(self.num_generations) # Compute advantages relative to group mean advantages = rewards - group_rewards_mean - advantages = maybe_normalize_advantages(advantages, group_rewards_std) + + # Normalize advantages based on scale_rewards setting + if self.scale_rewards == 'batch': + # Global batch-level normalization + rewards_std = rewards.std().expand_as(rewards) + elif self.scale_rewards == 'group': + # Group-level normalization (default) + rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations) + else: # 'none' + rewards_std = None + + if rewards_std is not None: + advantages = normalize_advantages(advantages, rewards_std) def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor): """Log reward statistics for monitoring. 
Only log once per unique request_id.""" @@ -823,7 +840,11 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] group_rewards = rewards.view(-1, self.num_generations) rewards_mean = group_rewards.mean(-1).mean().item() - rewards_std = group_rewards.std(-1).mean().item() + # Compute std based on scale_rewards setting for logging + if self.scale_rewards in ['group', 'none']: + rewards_std = group_rewards.std(-1).mean().item() + elif self.scale_rewards == 'batch': + rewards_std = rewards.std().item() is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1))) self._metrics[mode]['reward'].append(rewards_mean) From 29ecb32d2d76934f097c82262fed3978ae787e29 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 16:32:27 +0800 Subject: [PATCH 77/83] aggressive_empty_cache before wake up weights --- examples/megatron/grpo/dense_server.sh | 6 +++--- swift/trainers/rlhf_trainer/rollout_mixin.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh index 9505c9d37e..7361d2dcd0 100644 --- a/examples/megatron/grpo/dense_server.sh +++ b/examples/megatron/grpo/dense_server.sh @@ -35,9 +35,9 @@ megatron rlhf \ --external_plugins examples/train/grpo/plugin/plugin.py \ --reward_funcs external_r1v_acc format \ --use_vllm true \ - --vllm_mode colocate \ - --vllm_gpu_memory_utilization 0.7 \ - --vllm_max_model_len 10240 \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ --max_length 8192 \ --max_completion_length 2048 \ --train_type full \ diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index 17d8210021..3cb154a94c 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -637,6 +637,7 @@ def _fast_infer(self, inputs: DataType) -> DataType: if self.engine.inner_model_executor.is_sleeping: wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + aggressive_empty_cache() self.engine.engine.wake_up(**kwargs) if self.state.global_step != self._last_loaded_step: From ad00c5c6162a69c5b1ae97ef64f0a0f8f049391c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 17:09:30 +0800 Subject: [PATCH 78/83] docs --- README.md | 3 ++- README_CN.md | 3 ++- docs/source/Megatron-SWIFT/Command-line-parameters.md | 4 ++-- docs/source/Megatron-SWIFT/Multimodal-Model.md | 2 +- docs/source/Megatron-SWIFT/Quick-start.md | 1 + docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 6 +++--- docs/source_en/Megatron-SWIFT/Multimodal-Model.md | 2 +- docs/source_en/Megatron-SWIFT/Quick-start.md | 3 ++- 8 files changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index b263879aec..455ff142d3 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ You can contact us and communicate with us by adding our group: - **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ. - 🍊 **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, GKD, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models. 
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding. -- 🥥 **Megatron Parallelism**: Supports accelerating CPT/SFT/DPO/KTO/RM using Megatron parallelism techniques, currently compatible with 200+ pure text large models, 100+ multi-modal large models. +- 🥥 **Megatron Parallelism**: Supports accelerating CPT/SFT/GRPO/DPO/KTO/RM using Megatron parallelism techniques, currently compatible with 200+ pure text large models, 100+ multi-modal large models. - **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline. - **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer. - 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment. @@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group: ## 🎉 News +- 🎁 2025.11.14: Megatron GRPO is now available! Check out the [docs](./docs/source_en/Megatron-SWIFT/GRPO.md) and [examples](examples/megatron/grpo). - 🎁 2025.11.04: Support for [Mcore-Bridge](docs/source_en/Megatron-SWIFT/Mcore-Bridge.md), making Megatron training as simple and easy to use as transformers. - 🎁 2025.10.28: Ray [here](docs/source_en/Instruction/Ray.md). - 🎁 2025.10.28: Support [use yaml](examples/yaml) to configure command line parameters. diff --git a/README_CN.md b/README_CN.md index da2b914169..08a7f1b93d 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,7 +62,7 @@ - **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。 - 🍊 **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。 - 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。 -- 🥥 **Megatron并行技术**:支持使用Megatron并行技术对CPT/SFT/DPO/KTO/RM进行加速,现支持200+纯文本大模型和100+多模态大模型。 +- 🥥 **Megatron并行技术**:支持使用Megatron并行技术对CPT/SFT/GRPO/DPO/KTO/RM进行加速,现支持200+纯文本大模型和100+多模态大模型。 - **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。 - **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。 - 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。 @@ -71,6 +71,7 @@ - **模型量化**:支持AWQ、GPTQ、FP8和BNB的量化导出,导出的模型支持使用vLLM/SGLang/LmDeploy推理加速,并支持继续训练。 ## 🎉 新闻 +- 🎁 2025.11.14: Megatron GRPO现已支持!查看[文档](./docs/source/Megatron-SWIFT/GRPO.md)和[示例](examples/megatron/grpo)。 - 🎁 2025.11.04: 支持[Mcore-Bridge](docs/source/Megatron-SWIFT/Mcore-Bridge.md),使Megatron训练像transformers一样简单易用。 - 🎁 2025.10.28: Ray [已支持](docs/source/Instruction/Ray.md)。 - 🎁 2025.10.28: 已支持[使用yaml](examples/yaml)配置命令行参数。 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 8e204628cd..82c05590ac 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -252,7 +252,7 @@ lora训练: - model: safetensors权重的model_id或者model_path。默认为None。 - model_type: 模型类型。介绍参考[ms-swift命令行参数文档](../Instruction/Command-line-parameters.md)。 - adapters: safetensors格式的LoRA增量权重的adapter_id或者adapter_path。默认为`[]`。 -- ref_model: ref_model safetensors权重的model_id或者model_path。采用dpo、kto算法且使用全参数训练时需要传入。默认为None,设置为`--model`。 +- ref_model: ref_model safetensors权重的model_id或者model_path。采用grpo、dpo、kto算法且使用全参数训练时需要传入。默认为None,设置为`--model`。 - ref_adapters: ref_adapters 
safetensors权重的adapter_id或者adapter_path的列表(目前只支持长度为1),默认为`[]`。 - use_hf: 控制模型下载、数据集下载、模型推送使用ModelScope还是HuggingFace。默认为False,使用ModelScope。 - hub_token: hub token. modelscope的hub token可以查看[这里](https://modelscope.cn/my/myaccesstoken)。默认为None。 @@ -295,7 +295,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 ## RLHF参数 除了继承训练参数外,还支持以下参数: -- 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo'、'kto'和'rm'。 +- 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo'、'grpo'、'kto'和'rm'。 - loss_scale: 覆盖[基本参数](../Instruction/Command-line-parameters.md)中的loss_scale。默认为'last_round'。 - calculate_per_token_loss: 覆盖Megatron参数,默认为False。 diff --git a/docs/source/Megatron-SWIFT/Multimodal-Model.md b/docs/source/Megatron-SWIFT/Multimodal-Model.md index 9cc51732f7..8f51213211 100644 --- a/docs/source/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # 多模态模型 -ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO/KTO/RM。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/Supported-models-and-datasets.md)。 +ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/GRPO/DPO/KTO/RM。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/Supported-models-and-datasets.md)。 环境准备请参考Megatron-SWIFT的[快速开始文档](./Quick-start.md)。 diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index 8c92e2b6b9..faff26ecec 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -8,6 +8,7 @@ ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数 | ------ | ------ | ---- | ----- | ----- | | 预训练| ✅ | ✅| ✅ | ✅ | | 指令监督微调 | ✅ | ✅| ✅ | ✅ | +| GRPO | ✅ | ✅| ✅ | ✅ | | DPO | ✅ | ✅| ✅ | ✅ | | KTO | ✅ | ✅| ✅ | ✅ | | RM | ✅ | ✅| ✅ | ✅ | diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index bb632247e1..72e46f70f3 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -269,7 +269,7 @@ LoRA Training: - model: The model_id or model_path of safetensors weights. Defaults to None. - model_type: Model type. For details, refer to [ms-swift command-line parameters documentation](../Instruction/Command-line-parameters.md). - adapters: adapter_id or adapter_path of LoRA incremental weights in safetensors format. Default is `[]`. -- ref_model: model_id or model_path of ref_model safetensors weights. Required when using DPO or KTO algorithms with full-parameter training. Default is None, set to `--model`. +- ref_model: model_id or model_path of ref_model safetensors weights. Required when using DPO/GRPO/KTO algorithms with full-parameter training. Default is None, set to `--model`. - ref_adapters: List of adapter_id or adapter_path of ref_adapters safetensors weights (currently only supports length of 1). Default is `[]`. - use_hf: Controls whether to use ModelScope or HuggingFace for model download, dataset download, and model push. Default is False, using ModelScope. - hub_token: Hub token. ModelScope hub token can be found [here](https://modelscope.cn/my/myaccesstoken). Default is None. @@ -291,7 +291,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - Typically used together with `--freeze_vit false` and `--freeze_aligner false`. - aligner_lr: Specifies the learning rate for the aligner module in multimodal models. 
Default is `None`, same as `learning_rate`. - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled. -- 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM. +- 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, GRPO, DPO, KTO and RM. - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next. - Note: **Packing will reduce the number of dataset samples. Please adjust global_batch_size and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. @@ -315,7 +315,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa In addition to inheriting the training parameters, the following parameters are also supported: -- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'kto', and 'rm' are available. +- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'grpo', 'kto', and 'rm' are available. - loss_scale: Overrides the `loss_scale` in [basic parameters](../Instruction/Command-line-parameters.md). Default is 'last_round'. - calculate_per_token_loss: Overrides the Megatron parameter. Default is False. diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index 9f339cc547..d3d96dde1f 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # Multimodal Models -ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO/KTO/RM for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/GRPO/DPO/KTO/RM for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md). 
diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index ed46f0471f..94123c8c4e 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -7,9 +7,10 @@ ms-swift incorporates Megatron's parallelization techniques to accelerate the tr | ---------------------------------- | -------------- | ---- | ---- | ---------- | | Pretraining | ✅ | ✅ | ✅ | ✅ | | Instruction-supervised fine-tuning | ✅ | ✅ | ✅ | ✅ | +| GRPO | ✅ | ✅ | ✅ | ✅ | | DPO | ✅ | ✅ | ✅ | ✅ | | KTO | ✅ | ✅ | ✅ | ✅ | -| RM | ✅ | ✅ | ✅ | ✅ | +| RM | ✅ | ✅ | ✅ | ✅ | | Classification tasks | ✅ | ✅ | ✅ | ✅ | ## Environment Setup From 5977fe5f1924b68d1bf0f9cc079a618a94f87038 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 17:10:43 +0800 Subject: [PATCH 79/83] sleep level doc --- docs/source/Instruction/Command-line-parameters.md | 2 +- docs/source_en/Instruction/Command-line-parameters.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 40f9fda195..37fce6c67e 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -581,7 +581,7 @@ reward模型参数将在PPO、GRPO中使用。 - vllm_enable_prefix_caching: vllm透传参数,默认为True。 - vllm_tensor_parallel_size: tp并行数,默认为`1`。 - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。 - - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放。 + - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1, 2], 默认为0,不释放。 - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 - completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 6ce0658114..d0832bbf20 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -592,7 +592,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - vllm_enable_prefix_caching: A pass-through parameter for vLLM, default is True. - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1. - vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details. - - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep + - sleep_level: make vllm sleep when model is training. Options are 0/1/2, default is 0, no sleep - offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`. - offload_model: Whether to offload the model during inference with vLLM. The default is `False`. - completion_length_limit_scope: Specifies the scope of the `max_completion_length` limit in multi-turn conversations. 
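The `sleep_level`, `offload_model`, and `offload_optimizer` options touched by the last few patches describe a memory hand-off between the trainer and a colocated vLLM engine. The wrapper below is only a rough outline of the order of operations visible in the `_generate_completions` diff earlier in this series (wake weights, sync weights, wake KV cache, rollout, sleep); the `colocate_generate` function and the `trainer` object are hypothetical, while `wake_up`, `reset_prefix_cache`, `sleep`, `_move_model_to_vllm`, and `_rollout` are the names used in that diff.

```python
def colocate_generate(trainer, batch):
    """Rough outline of the colocate rollout cycle; the wrapper itself is illustrative."""
    engine = trainer.engine.engine  # colocated vLLM engine
    # 1. wake only the weights first to keep the memory peak low
    engine.wake_up(tags=['weights'])
    # 2. export the latest Megatron weights into vLLM
    trainer._move_model_to_vllm()
    # 3. restore the KV cache once the weight transfer is done
    engine.wake_up(tags=['kv_cache'])
    # 4. sample completions
    outputs = trainer._rollout(batch)
    # 5. release vLLM memory again before the next optimizer step
    if trainer.args.sleep_level > 0:
        engine.reset_prefix_cache()
        engine.sleep(level=trainer.args.sleep_level)
    return outputs
```

Splitting wake-up into a weights pass and a KV-cache pass mirrors the diff's comment about reducing the memory peak; on the Megatron side, `offload_model` and `offload_optimizer` move parameters and optimizer state to CPU around this window (see the `offload_context` changes in the same patch).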
From 5ab6d371b0d3014101d2c4d0544d34090042fb7a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 17:25:38 +0800 Subject: [PATCH 80/83] fix kl metrics --- swift/megatron/trainers/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 53269a8483..d3253d4b39 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1129,7 +1129,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): if self.beta != 0.0: # Unified processing (no CP-specific logic needed) kl_value = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) - avg_metric['kl'] = kl_value.item() + avg_metric['kl'] = kl_value.clone().detach() mode = 'train' if self.unwrapped_models[0].training else 'eval' if self._metrics[mode]: From 2841fb9a3c2f5282fc52431be4ef34ff0bf94dc5 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 14 Nov 2025 17:43:14 +0800 Subject: [PATCH 81/83] fix arxiv link & fix kl metric --- docs/source/Instruction/Command-line-parameters.md | 11 +++++------ docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md | 2 +- docs/source/Megatron-SWIFT/Command-line-parameters.md | 10 +++++----- docs/source_en/Instruction/Command-line-parameters.md | 2 +- .../Instruction/GRPO/AdvancedResearch/GSPO.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 2 +- examples/megatron/grpo/dense_colocate.sh | 1 - examples/megatron/grpo/dense_server.sh | 4 ++-- swift/megatron/argument/megatron_args.py | 2 +- swift/trainers/arguments.py | 2 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 2 +- 11 files changed, 19 insertions(+), 21 deletions(-) mode change 100644 => 100755 examples/megatron/grpo/dense_server.sh diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 37fce6c67e..1d15a21492 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -566,13 +566,12 @@ reward模型参数将在PPO、GRPO中使用。 - use_vllm: 是否使用 vLLM 作为 GRPO 生成的 infer_backend,默认为False。 - vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, - vllm_mode server 参数 - - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 - - vllm_server_host:vLLM server host地址,默认为None。 - - vllm_server_port vLLM server 服务端口,默认为8000。 - - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 + - vllm_server_host: vLLM server host地址,默认为None。 + - vllm_server_port: vLLM server 服务端口,默认为8000。 + - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。 - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. 
- - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: 环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 - vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 - vllm_max_model_len: vllm透传参数,默认为None。 @@ -593,7 +592,7 @@ reward模型参数将在PPO、GRPO中使用。 - max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 - overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 - delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 - advantage_estimator: 优势计算函数,默认为 `grpo`,即计算组内相对优势,可选项为 `grpo`、[`rloo`](./GRPO/AdvancedResearch/RLOO.md)、[`reinforce_plus_plus`](./GRPO/AdvancedResearch/REINFORCEPP.md)。 - kl_in_reward: 控制 KL 散度正则项的处理位置;`false`表示作为损失函数的独立正则项,`true`表示将 KL 直接并入奖励(从奖励中扣除)。默认情况与advantage_estimator绑定,`grpo`下默认为`false`,`rloo` 和 `reinforce_plus_plus` 下默认为 `true`。 - scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。 diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md index 6dc03118e2..9bc9df2f80 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -2,7 +2,7 @@ **版本依赖**:ms-swift>=3.7 -[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071)中指出GRPO在计算重要性采样权重时,是在token级别进行操作的。然而,这种做法由于每个token仅采样一次,无法实现有效的分布校正,反而会在模型训练过程中引入高方差噪声,极易导致模型的梯度估计不稳定,最终造成模型训练的崩塌。因此,论文认为,优化目标的单位应该与奖励的单位保持一致。由于奖励通常是在序列级别(即完整生成的回复)给出的,因此更合理的做法是将 off-policy 校正和优化也提升到序列级别,而非 token 级别。以下是三种计算策略对比: +[Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071)中指出GRPO在计算重要性采样权重时,是在token级别进行操作的。然而,这种做法由于每个token仅采样一次,无法实现有效的分布校正,反而会在模型训练过程中引入高方差噪声,极易导致模型的梯度估计不稳定,最终造成模型训练的崩塌。因此,论文认为,优化目标的单位应该与奖励的单位保持一致。由于奖励通常是在序列级别(即完整生成的回复)给出的,因此更合理的做法是将 off-policy 校正和优化也提升到序列级别,而非 token 级别。以下是三种计算策略对比: 1. 
GRPO 对每个 token 独立计算重要性采样比,具体公式为 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 82c05590ac..36ab4c6809 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -339,12 +339,12 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, - vllm_mode server 参数 - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 - - vllm_server_host:vLLM server host地址,默认为None。 - - vllm_server_port vLLM server 服务端口,默认为8000。 - - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 + - vllm_server_host: vLLM server host地址,默认为None。 + - vllm_server_port: vLLM server 服务端口,默认为8000。 + - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。 - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. - - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: 环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 - vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 - vllm_max_model_len: vllm透传参数,默认为None。 @@ -363,7 +363,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 - overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 - delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 - scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。 内置奖励函数参数参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数) diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index d0832bbf20..d4f704c7d5 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -607,7 +607,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions). - delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). -- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. 
In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. +- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. - advantage_estimator: Advantage estimator. Default is `grpo` (group-relative advantage). Options: `grpo`, [`rloo`](./GRPO/AdvancedResearch/RLOO.md), [`reinforce_plus_plus`](./GRPO/AdvancedResearch/REINFORCEPP.md). - kl_in_reward: Controls where the KL regularization is applied. `false`: KL is a separate loss term. `true`: KL is subtracted from the reward. The default is bound to `advantage_estimator`: `false` for `grpo`, and `true` for `rloo` and `reinforce_plus_plus`. - scale_rewards: Specifies the reward scaling strategy. Options: `group` (scale by intra-group std), `batch` (scale by batch-wide std), `none` (no scaling). In ms-swift < 3.10, this was a boolean where `true` corresponds to `group` and `false` to `none`. The default is bound to `advantage_estimator`: `group` for `grpo`, `none` for `rloo`, and `batch` for `reinforce_plus_plus`. diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md index 5062188bf2..03c67b3c6e 100644 --- a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -2,7 +2,7 @@ **Version Requirement**: ms-swift>=3.8 -In [Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level. +In [Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level. 
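Patch 76 above ("align scale_rewards") makes `scale_rewards` the group/batch/none choice that the documentation lines in this patch describe, and the three modes differ only in which standard deviation divides the group-centered rewards. A minimal sketch, assuming rewards have already been gathered and are laid out group by group:

```python
import torch


def grpo_advantages(rewards: torch.Tensor,
                    num_generations: int,
                    scale_rewards: str = 'group') -> torch.Tensor:
    """Sketch of group-relative advantages under the three scale_rewards modes.

    rewards: shape [num_prompts * num_generations]; consecutive blocks of
             num_generations entries belong to the same prompt (group).
    """
    grouped = rewards.view(-1, num_generations)
    group_mean = grouped.mean(dim=1).repeat_interleave(num_generations)
    advantages = rewards - group_mean

    if scale_rewards == 'none':
        return advantages                       # center only, no scaling
    if scale_rewards == 'group':
        std = grouped.std(dim=1).repeat_interleave(num_generations)
    elif scale_rewards == 'batch':
        std = rewards.std().expand_as(rewards)  # one std for the whole batch
    else:
        raise ValueError(f'unsupported scale_rewards: {scale_rewards}')
    return advantages / (std + 1e-4)            # small epsilon as in the diff


# toy usage: 2 prompts x 4 generations each
rewards = torch.tensor([1., 0., 1., 1., 0., 0., 1., 0.])
print(grpo_advantages(rewards, num_generations=4, scale_rewards='group'))
```

Only the divisor changes between the modes; centering by the per-group mean is common to all three, as in the `_compute_advantages` diff.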
Below are the three main strategies for computing importance sampling weights: diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 72e46f70f3..6f387d89d6 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -387,7 +387,7 @@ In addition to inheriting the training parameters, the following parameters are - max_resample_times: Limit the number of resampling times under dynamic_sample setting. Default is 3. - overlong_filter: Skip overlong truncated samples, which do not participate in loss calculation. Default is False. - delta: Bilateral GRPO upper bound clipping value from the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). If set, it is recommended to be greater than 1 + epsilon. Default is None. -- importance_sampling_level: Controls importance sampling ratio calculation. Options are `token` and `sequence`. In `token` mode, the original log probability ratio for each token is preserved. In `sequence` mode, the log probability ratios of all valid tokens in the sequence are averaged. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level calculation to stabilize training. Default is `token`. +- importance_sampling_level: Controls importance sampling ratio calculation. Options are `token` and `sequence`. In `token` mode, the original log probability ratio for each token is preserved. In `sequence` mode, the log probability ratios of all valid tokens in the sequence are averaged. The [GSPO paper](https://arxiv.org/abs/2507.18071) uses sequence-level calculation to stabilize training. Default is `token`. - scale_rewards: Specifies the reward scaling strategy. Options include `group` (scale by within-group standard deviation), `batch` (scale by batch-wide standard deviation), and `none` (no scaling). In ms-swift < 3.10, this parameter is boolean, where `true` corresponds to `group` and `false` corresponds to `none`. The default value is bound to `advantage_estimator`: `grpo` corresponds to `group`, `rloo` corresponds to `none`, and `reinforce_plus_plus` corresponds to `batch`. Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters). 
diff --git a/examples/megatron/grpo/dense_colocate.sh b/examples/megatron/grpo/dense_colocate.sh
index a7b79f2bd8..4cbd7cafbb 100644
--- a/examples/megatron/grpo/dense_colocate.sh
+++ b/examples/megatron/grpo/dense_colocate.sh
@@ -48,7 +48,6 @@ megatron rlhf \
     --offload_optimizer true \
     --log_interval 1 \
     --recompute_granularity selective \
-    --max_epochs 1 \
     --finetune \
     --num_workers 8 \
     --dataset_num_proc 8 \
diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh
old mode 100644
new mode 100755
index 7361d2dcd0..9ff30766cb
--- a/examples/megatron/grpo/dense_server.sh
+++ b/examples/megatron/grpo/dense_server.sh
@@ -55,7 +55,6 @@ megatron rlhf \
     --offload_optimizer true \
     --log_interval 1 \
     --recompute_granularity selective \
-    --max_epochs 1 \
     --finetune \
     --num_workers 8 \
     --dataset_num_proc 8 \
@@ -70,4 +69,5 @@ megatron rlhf \
     --wandb_exp_name megatron_grpo \
     --train_iters 100 \
     --eval_interval 1000 \
-    --save_interval 1000
+    --save_interval 1000 \
+    --scale_rewards group
diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py
index e0ba114de2..4609df7948 100644
--- a/swift/megatron/argument/megatron_args.py
+++ b/swift/megatron/argument/megatron_args.py
@@ -41,7 +41,7 @@ class RLHFMegatronArgumentsMixin:
     steps_per_generation: Optional[int] = None
     num_generations: int = 8
     max_completion_length: int = 512
-    # GSPO https://www.arxiv.org/abs/2507.18071
+    # GSPO https://arxiv.org/abs/2507.18071
     importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'
 
     epsilon: float = 0.2
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index 968f5cedf5..42f6afdcdd 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -324,7 +324,7 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
     # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
     top_entropy_quantile: float = 1.0
 
-    # GSPO https://www.arxiv.org/abs/2507.18071
+    # GSPO https://arxiv.org/abs/2507.18071
     importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'
 
     # RLOO, REINFORCE++
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 1b81ffaeb0..53cc4b5c99 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1861,7 +1861,7 @@ def _prepare_algorithm_params(self):
         # Entropy Mask, https://arxiv.org/abs/2506.01939
         self.top_entropy_quantile = args.top_entropy_quantile
 
-        # GSPO, https://www.arxiv.org/abs/2507.18071
+        # GSPO, https://arxiv.org/abs/2507.18071
         self.importance_sampling_level = args.importance_sampling_level
 
         # RLOO,

From 2a97c64021a8e0080b7dd24ed5d3d9346e100ade Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Fri, 14 Nov 2025 17:44:02 +0800
Subject: [PATCH 82/83] revert script

---
 examples/megatron/grpo/dense_server.sh | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)
 mode change 100755 => 100644 examples/megatron/grpo/dense_server.sh

diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh
old mode 100755
new mode 100644
index 9ff30766cb..ee702800e2
--- a/examples/megatron/grpo/dense_server.sh
+++ b/examples/megatron/grpo/dense_server.sh
@@ -69,5 +69,4 @@ megatron rlhf \
     --wandb_exp_name megatron_grpo \
     --train_iters 100 \
     --eval_interval 1000 \
-    --save_interval 1000 \
-    --scale_rewards group
+    --save_interval 1000

From 40706d7b10a3933ac1400a1caab6f6c544618a0e Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Fri, 14 Nov 2025 18:13:34 +0800
Subject: [PATCH 83/83] revert server_base_url doc

---
 docs/source/Instruction/Command-line-parameters.md       | 1 +
 docs/source/Megatron-SWIFT/Command-line-parameters.md    | 2 +-
 docs/source_en/Instruction/Command-line-parameters.md    | 2 +-
 docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 +-
 4 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md
index 1d15a21492..2a2f9db2d4 100644
--- a/docs/source/Instruction/Command-line-parameters.md
+++ b/docs/source/Instruction/Command-line-parameters.md
@@ -568,6 +568,7 @@ reward模型参数将在PPO、GRPO中使用。
 - vllm_mode server 参数
   - vllm_server_host: vLLM server host地址,默认为None。
   - vllm_server_port: vLLM server 服务端口,默认为8000。
+  - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。
   - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。
   - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。
 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`.
diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md
index 36ab4c6809..5c75aa28c0 100644
--- a/docs/source/Megatron-SWIFT/Command-line-parameters.md
+++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md
@@ -338,9 +338,9 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用
 - log_completions: 是否记录训练中的模型生成内容,默认为False。
 - vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时,
 - vllm_mode server 参数
-  - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。
   - vllm_server_host: vLLM server host地址,默认为None。
   - vllm_server_port: vLLM server 服务端口,默认为8000。
+  - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。
   - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。
   - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。
 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`.
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index d4f704c7d5..c7dd3d865a 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -577,9 +577,9 @@ The meanings of the following parameters can be referenced [here](https://huggin
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
 - vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate`
 - vllm_mode server parameter
-  - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None.
   - vllm_server_host: The host address of the vLLM server. Default is None.
   - vllm_server_port: The service port of the vLLM server. Default is 8000.
+  - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None.
   - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.
   - vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training.
 - async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`.
diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md
index 6f387d89d6..446916c1f7 100644
--- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md
+++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md
@@ -362,9 +362,9 @@ In addition to inheriting the training parameters, the following parameters are
 - log_completions: Whether to log model-generated content during training. Default is False.
 - vllm_mode: vLLM integration mode. Options are `server` and `colocate`. Server mode uses the vLLM server launched by `swift rollout` for sampling, while colocate mode deploys vLLM within the program. When using server mode:
 - vllm_mode server parameters:
-  - vllm_server_base_url: Base URL of the vLLM server (e.g., http://local_host:8000). Default is None. When set, host and port settings are ignored.
   - vllm_server_host: vLLM server host address. Default is None.
   - vllm_server_port: vLLM server port. Default is 8000.
+  - vllm_server_base_url: Base URL of the vLLM server (e.g., http://local_host:8000). Default is None. When set, host and port settings are ignored.
   - vllm_server_timeout: Timeout for connecting to the vLLM server. Default is 240s.
   - vllm_server_pass_dataset: Pass additional dataset information to the vLLM server for multi-round training.
 - async_generate: Asynchronous rollout to improve training speed. Note: When enabled, sampling uses the model from the previous round update, and multi-round scenarios are not supported. Default is `false`.
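To make the precedence rule described in these doc lines concrete (a set `vllm_server_base_url` causes `vllm_server_host`/`vllm_server_port` to be ignored), here is a minimal sketch; `resolve_vllm_server_urls` is a hypothetical helper written for illustration, not an ms-swift API.

```python
from typing import List, Optional


def resolve_vllm_server_urls(base_urls: Optional[List[str]] = None,
                             hosts: Optional[List[str]] = None,
                             ports: Optional[List[int]] = None) -> List[str]:
    """Return the rollout-server URLs a trainer would connect to in server mode."""
    if base_urls:
        # Documented behaviour: when base URLs are given, host/port settings are ignored.
        return base_urls
    if not hosts:
        raise ValueError('server mode needs vllm_server_base_url or vllm_server_host')
    ports = ports or [8000] * len(hosts)  # documented default port is 8000
    return [f'http://{host}:{port}' for host, port in zip(hosts, ports)]


# base_url takes precedence over host/port:
assert resolve_vllm_server_urls(['http://localhost:8000'], ['10.0.0.1'], [8001]) == ['http://localhost:8000']
```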