From 162cb3fc09dd38e052acdcedaf5d8e14b45b29f8 Mon Sep 17 00:00:00 2001
From: Manpreet Singh <138612831+DevManpreet5@users.noreply.github.com>
Date: Tue, 27 May 2025 10:29:46 +0530
Subject: [PATCH 1/2] Improved Training Loop in train.py

---
 train.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 92 insertions(+), 17 deletions(-)

diff --git a/train.py b/train.py
index f9c7e00..5e2fe52 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,8 @@
 import logging
+import os
 import wandb
 from functools import partial
+from pathlib import Path
 
 import torch
 from datasets import load_dataset
@@ -8,7 +10,7 @@
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
 from config import Configuration
-from utils import train_collate_function
+from utils import train_collate_function, test_collate_function
 
 import albumentations as A
 
@@ -25,45 +27,115 @@
 ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True))
 
 
-def get_dataloader(processor):
-    logger.info("Fetching the dataset")
+def get_dataloaders(processor):
+    logger.info("Fetching the datasets")
     train_dataset = load_dataset(cfg.dataset_id, split="train")
+    val_dataset = load_dataset(cfg.dataset_id, split="validation")
+
     train_collate_fn = partial(
         train_collate_function, processor=processor, dtype=cfg.dtype,
         transform=augmentations
     )
+    val_collate_fn = partial(
+        test_collate_function, processor=processor, dtype=cfg.dtype
+    )
 
-    logger.info("Building data loader")
+    logger.info("Building data loaders")
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=cfg.batch_size,
         collate_fn=train_collate_fn,
         shuffle=True,
     )
-    return train_dataloader
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=cfg.batch_size,
+        collate_fn=val_collate_fn,
+        shuffle=False,
+    )
+    return train_dataloader, val_dataloader
+
+
+def evaluate_model(model, val_dataloader, device):
+    model.eval()
+    total_loss = 0.0
+    total_samples = 0
+
+    with torch.no_grad():
+        for batch in val_dataloader:
+            batch = batch.to(device)
+            outputs = model(**batch)
+            loss = outputs.loss
+            total_loss += loss.item() * batch["input_ids"].size(0)
+            total_samples += batch["input_ids"].size(0)
+
+    avg_loss = total_loss / total_samples
+    model.train()
+    return avg_loss
 
 
-def train_model(model, optimizer, cfg, train_dataloader):
+def train_model(model, optimizer, cfg, train_dataloader, val_dataloader):
     logger.info("Start training")
     global_step = 0
+    best_val_loss = float('inf')
+    checkpoint_dir = Path("checkpoints")
+    checkpoint_dir.mkdir(exist_ok=True)
+
     for epoch in range(cfg.epochs):
+        # Training loop
+        model.train()
+        train_loss = 0.0
+        train_samples = 0
+
         for idx, batch in enumerate(train_dataloader):
-            outputs = model(**batch.to(model.device))
+            batch = batch.to(model.device)
+            outputs = model(**batch)
             loss = outputs.loss
+
+            train_loss += loss.item() * batch["input_ids"].size(0)
+            train_samples += batch["input_ids"].size(0)
+
             if idx % 100 == 0:
                 logger.info(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}")
-                wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step)
+                wandb.log({
+                    "train/step_loss": loss.item(),
+                    "epoch": epoch,
+                    "step": global_step
+                })
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
             global_step += 1
+
+        avg_train_loss = train_loss / train_samples
+        wandb.log({"train/epoch_loss": avg_train_loss, "epoch": epoch})
+
+        val_loss = evaluate_model(model, val_dataloader, cfg.device)
+        wandb.log({"val/loss": val_loss, "epoch": epoch})
+        logger.info(f"Epoch: {epoch} Train Loss: {avg_train_loss:.4f} Val Loss: {val_loss:.4f}")
+
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            checkpoint_path = checkpoint_dir / f"best_model_epoch_{epoch}.pt"
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss': val_loss,
+            }, checkpoint_path)
+            logger.info(f"New best model saved at {checkpoint_path} with val loss {val_loss:.4f}")
+
+        if epoch % cfg.save_every == 0:
+            model.push_to_hub(cfg.checkpoint_id, commit_message=f"Epoch {epoch} - Val loss {val_loss:.4f}")
+            processor.push_to_hub(cfg.checkpoint_id)
+
     return model
 
 
 if __name__ == "__main__":
     cfg = Configuration()
     processor = AutoProcessor.from_pretrained(cfg.model_id)
-    train_dataloader = get_dataloader(processor)
+    train_dataloader, val_dataloader = get_dataloaders(processor)
 
     logger.info("Getting model & turning only attention parameters to trainable")
     model = Gemma3ForConditionalGeneration.from_pretrained(
@@ -91,11 +163,14 @@ def train_model(model, optimizer, cfg, train_dataloader):
         config=vars(cfg),
     )
 
-    train_model(model, optimizer, cfg, train_dataloader)
-
-    # Push the checkpoint to hub
-    model.push_to_hub(cfg.checkpoint_id)
-    processor.push_to_hub(cfg.checkpoint_id)
-
-    wandb.finish()
-    logger.info("Train finished")
+    try:
+        train_model(model, optimizer, cfg, train_dataloader, val_dataloader)
+    except Exception as e:
+        logger.error(f"Training failed: {str(e)}")
+        raise
+    finally:
+        # Push the final checkpoint to hub
+        model.push_to_hub(cfg.checkpoint_id, commit_message="Final model")
+        processor.push_to_hub(cfg.checkpoint_id)
+        wandb.finish()
+        logger.info("Training finished")

From bbf31ab6e78c2c85130405bc2b9b2c279e86c75e Mon Sep 17 00:00:00 2001
From: Manpreet Singh <138612831+DevManpreet5@users.noreply.github.com>
Date: Tue, 27 May 2025 10:30:22 +0530
Subject: [PATCH 2/2] added project_name for w&b tracking

---
 config.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/config.py b/config.py
index 5461122..a62e70a 100644
--- a/config.py
+++ b/config.py
@@ -16,3 +16,6 @@ class Configuration:
     batch_size: int = 8
     learning_rate: float = 2e-05
     epochs = 2
+    save_every: int = 1
+    project_name: str = "gemma-object-detection"
+    run_name: str = "exp1"
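
The best-model checkpoint written by train_model in PATCH 1/2 is a plain
dict holding the epoch, the model and optimizer state dicts, and the
validation loss. A minimal sketch of resuming from such a file follows; it
is not part of either patch, the file name is only a placeholder, and it
assumes model and optimizer have already been constructed as in train.py's
__main__ block:

    import torch

    # Load the dict saved by train_model; the path is a placeholder for
    # whichever epoch produced the best validation loss.
    checkpoint = torch.load("checkpoints/best_model_epoch_0.pt", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1   # continue after the saved epoch
    best_val_loss = checkpoint["loss"]      # so later epochs compare against it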
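
PATCH 2/2 adds project_name and run_name to Configuration, but neither diff
shows wandb.init consuming them. A sketch of how they could be wired into
the existing wandb.init call in train.py, which already passes
config=vars(cfg):

    wandb.init(
        project=cfg.project_name,  # new field from PATCH 2/2
        name=cfg.run_name,         # new field from PATCH 2/2
        config=vars(cfg),
    )

One caveat, assuming Configuration is a dataclass (consistent with its
annotated defaults): epochs = 2 in config.py carries no type annotation, so
it is a plain class attribute rather than a dataclass field and will not
appear in vars(cfg); annotating it as epochs: int = 2 would include it in
the logged run config.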