Skip to content
9 changes: 5 additions & 4 deletions QEfficient/finetune/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,12 @@ def train(
break

if train_config.use_peft and train_config.from_peft_checkpoint:
path = train_config.from_peft_checkpoint.rstrip("/")
try:
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
intermediate_epoch = int(path.split("/")[-2].split("_")[-1]) - 1
intermediate_step = int(path.split("/")[-1].split("_")[-1])
except (IndexError, ValueError):
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1]) - 1
intermediate_epoch = int(path.split("/")[-1].split("_")[-1]) - 1
intermediate_step = 0

if epoch < intermediate_epoch:
Expand Down Expand Up @@ -374,7 +375,7 @@ def train(
eval_step_metric,
eval_metric,
)
avg_epoch_time = sum(epoch_times) / len(epoch_times)
avg_epoch_time = sum(epoch_times) / len(epoch_times) if len(epoch_times) > 0 else 0
avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0

results["last_epoch_train_loss"] = train_epoch_loss.cpu()
Expand Down
Loading