@@ -124,10 +124,9 @@ def train(
 
         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
+            intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
             if epoch < intermediate_epoch:
                 logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
-                # to bring the count of train_step in sync with where it left off
-                total_train_steps += len(train_dataloader)
                 continue
 
         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
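The resume logic above recovers where the previous run stopped purely from the checkpoint path. A minimal standalone sketch of that parsing, assuming a hypothetical layout such as runs/llama/epoch_3/step_500 (only the trailing _<number> segments matter to the code; the rest of the path is made up):

from_peft_checkpoint = "runs/llama/epoch_3/step_500"  # hypothetical path, not a real artifact

# "epoch_3" -> 3, minus 1 because the training loop counts epochs from 0
intermediate_epoch = int(from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
# "step_500" -> 500 steps already finished within that epoch
intermediate_step = int(from_peft_checkpoint.split("/")[-1].split("_")[-1])

print(intermediate_epoch, intermediate_step)  # 2 500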
@@ -149,20 +148,18 @@ def train(
 
         num_dummy_samples = 0
         for step, batch in enumerate(train_dataloader):
+            # total_train_steps indicates the cumulative number of training steps completed across all epochs.
+            # When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
+            total_train_steps = (epoch) * len(train_dataloader) + step
             # resume training from a particular checkpoint, assuming the dataset is not shuffled
             if train_config.use_peft and train_config.from_peft_checkpoint:
-                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
-                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
-                    total_train_steps += intermediate_step
                     logger.log_rank_zero(
                         f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
-                    total_train_steps += 1
                     continue
-            total_train_steps += 1
 
             if train_config.max_train_step > 0 and total_train_steps >= train_config.max_train_step:
                 max_steps_reached = True
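The main change in this hunk is that total_train_steps is now recomputed from the loop indices instead of being incremented, so skipped epochs and steps can never desynchronize the counter. A small self-contained sketch of that arithmetic, using made-up numbers for the dataloader length and resume point:

steps_per_epoch = 250                           # stand-in for len(train_dataloader)
intermediate_epoch, intermediate_step = 2, 100  # hypothetical resume point
num_epochs = 4

for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        # Cumulative step index across all epochs, valid whether or not steps are skipped.
        total_train_steps = epoch * steps_per_epoch + step
        if epoch < intermediate_epoch or (epoch == intermediate_epoch and step < intermediate_step):
            continue  # already trained in the earlier session
        # ... a real loop would run forward/backward here ...

print(total_train_steps)  # 999, the index of the final step of the last epoch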
@@ -235,12 +232,12 @@ def train(
             else:
                 num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps
 
-            loss = loss / num_samples_in_cur_update
+            normalized_loss = loss / num_samples_in_cur_update
 
             if train_config.grad_scaler:
-                scaler.scale(loss).backward()  # backward pass
+                scaler.scale(normalized_loss).backward()  # backward pass
             else:
-                loss.backward()  # backward pass
+                normalized_loss.backward()  # backward pass
 
             if is_optimizer_step:
                 if train_config.grad_scaler:
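The last hunk keeps the raw loss separate from the value fed to backward(), so logging still sees the unscaled loss while gradients are averaged over the accumulation window. A rough illustration of that pattern in plain PyTorch; the model, optimizer, and accumulation factor below are placeholders, not objects from this repository:

import torch

model = torch.nn.Linear(8, 1)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4                          # placeholder accumulation factor
batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

for step, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)      # raw loss, suitable for logging
    normalized_loss = loss / gradient_accumulation_steps  # scaled copy used only for backward
    normalized_loss.backward()                            # gradients accumulate across micro-batches

    if (step + 1) % gradient_accumulation_steps == 0:     # one optimizer step per accumulation window
        optimizer.step()
        optimizer.zero_grad()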