diff --git a/swift/trainers/callback.py b/swift/trainers/callback.py index be299a8fe9..d77464e43c 100644 --- a/swift/trainers/callback.py +++ b/swift/trainers/callback.py @@ -44,9 +44,11 @@ def add_train_message(logs, state, start_time) -> None: class ProgressCallbackNew(ProgressCallback): def on_train_begin(self, args, state, control, **kwargs): + initial_step = state.global_step or 0 if state.is_world_process_zero: - self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True) - self.current_step = 0 + bar_initial = min(initial_step, state.max_steps) if state.max_steps else initial_step + self.training_bar = tqdm(desc='Train', total=state.max_steps, initial=bar_initial, dynamic_ncols=True) + self.current_step = initial_step self.start_time = time.time() def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs):