Skip to content

Commit c75a637

Browse files
authored
"[QEff. Finetune]: Support for resuming checkpoints using Epoch" (#614)
Signed-off-by: Tanisha <tchawada@qti.qualcomm.com>
1 parent 04f1ad7 commit c75a637

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

QEfficient/finetune/utils/train_utils.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -123,11 +123,12 @@ def train(
123123
break
124124

125125
if train_config.use_peft and train_config.from_peft_checkpoint:
126+
path = train_config.from_peft_checkpoint.rstrip("/")
126127
try:
127-
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
128-
intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
128+
intermediate_epoch = int(path.split("/")[-2].split("_")[-1]) - 1
129+
intermediate_step = int(path.split("/")[-1].split("_")[-1])
129130
except (IndexError, ValueError):
130-
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1]) - 1
131+
intermediate_epoch = int(path.split("/")[-1].split("_")[-1]) - 1
131132
intermediate_step = 0
132133

133134
if epoch < intermediate_epoch:
@@ -374,7 +375,7 @@ def train(
374375
eval_step_metric,
375376
eval_metric,
376377
)
377-
avg_epoch_time = sum(epoch_times) / len(epoch_times)
378+
avg_epoch_time = sum(epoch_times) / len(epoch_times) if len(epoch_times) > 0 else 0
378379
avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
379380

380381
results["last_epoch_train_loss"] = train_epoch_loss.cpu()

0 commit comments

Comments
 (0)