diff --git a/predict.py b/predict.py
index 4d49652..fd3de3e 100644
--- a/predict.py
+++ b/predict.py
@@ -3,13 +3,31 @@
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BlipForConditionalGeneration, Gemma3ForConditionalGeneration
 
 from config import Configuration
 from utils import test_collate_function, visualize_bounding_boxes
+import argparse
 
 os.makedirs("outputs", exist_ok=True)
 
+model_class_map = [
+    (lambda name: "gemma" in name, Gemma3ForConditionalGeneration),
+    (lambda name: "blip" in name, BlipForConditionalGeneration),
+    (lambda name: "kimi" in name, AutoModelForCausalLM),
+]
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Fine Tune Gemma3 for Object Detection")
+    parser.add_argument("--model", type=str, help="Model checkpoint identifier")
+    return parser.parse_args()
+
+def get_model_class(model_name):
+    model_name = model_name.lower()
+    for condition, model_class in model_class_map:
+        if condition(model_name):
+            return model_class
+    return AutoModelForSeq2SeqLM
 
 
 def get_dataloader(processor):
     test_dataset = load_dataset(cfg.dataset_id, split="test")
@@ -21,15 +39,20 @@ def get_dataloader(processor):
     )
     return test_dataloader
 
-
 if __name__ == "__main__":
+    args = parse_args()
     cfg = Configuration()
+    if args.model:
+        cfg.model_id = args.model
+
     processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
-    model = Gemma3ForConditionalGeneration.from_pretrained(
+    model_class = get_model_class(cfg.model_id)
+    model = model_class.from_pretrained(
         cfg.checkpoint_id,
         torch_dtype=cfg.dtype,
         device_map="cpu",
-    )
+    )
+    model.eval()
     model.to(cfg.device)
diff --git a/train.py b/train.py
index f9c7e00..c84334c 100644
--- a/train.py
+++ b/train.py
@@ -5,12 +5,13 @@
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BlipForConditionalGeneration, Gemma3ForConditionalGeneration
 
 from config import Configuration
 from utils import train_collate_function
 import albumentations as A
+import argparse
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
@@ -24,6 +25,23 @@
     A.ColorJitter(p=0.2),
 ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True))
 
+model_class_map = [
+    (lambda name: "gemma" in name, Gemma3ForConditionalGeneration),
+    (lambda name: "blip" in name, BlipForConditionalGeneration),
+    (lambda name: "kimi" in name, AutoModelForCausalLM),
+]
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Fine Tune Gemma3 for Object Detection")
+    parser.add_argument("--model", type=str, help="Model checkpoint identifier")
+    return parser.parse_args()
+
+def get_model_class(model_name):
+    model_name = model_name.lower()
+    for condition, model_class in model_class_map:
+        if condition(model_name):
+            return model_class
+    return AutoModelForSeq2SeqLM
 
 
 def get_dataloader(processor):
     logger.info("Fetching the dataset")
@@ -61,17 +79,30 @@ def train_model(model, optimizer, cfg, train_dataloader):
 if __name__ == "__main__":
+    args = parse_args()
     cfg = Configuration()
+    if args.model:
+        cfg.model_id = args.model
     processor = AutoProcessor.from_pretrained(cfg.model_id)
+    model_class = get_model_class(cfg.model_id)
     train_dataloader = get_dataloader(processor)
     logger.info("Getting model & turning only attention parameters to trainable")
-    model = Gemma3ForConditionalGeneration.from_pretrained(
-        cfg.model_id,
-        torch_dtype=cfg.dtype,
-        device_map="cpu",
-        attn_implementation="eager",
-    )
+
+    if "gemma" in cfg.model_id.lower():
+        model = model_class.from_pretrained(
+            cfg.model_id,
+            torch_dtype=cfg.dtype,
+            device_map="cpu",
+            attn_implementation="eager",
+        )
+    else:
+        model = model_class.from_pretrained(
+            cfg.model_id,
+            torch_dtype=cfg.dtype,
+            device_map="cpu",
+        )
+
     for name, param in model.named_parameters():
         if "attn" in name:
             param.requires_grad = True
@@ -86,7 +117,7 @@ def train_model(model, optimizer, cfg, train_dataloader):
     optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate)
 
     wandb.init(
-        project=cfg.project_name,
+        project=cfg.project_name if hasattr(cfg, "project_name") else None,
         name=cfg.run_name if hasattr(cfg, "run_name") else None,
         config=vars(cfg),
     )
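
A minimal sketch of how the new --model override and class dispatch behave (illustrative only: the checkpoint identifiers below are placeholders, not taken from this patch, and it assumes train.py is importable from the repo root with its dependencies installed):

# Illustrative check of the dispatch table added in this patch; importing train
# runs its module-level setup (logging, albumentations transforms) but does not
# start a training run.
from train import get_model_class

# Any identifier containing "gemma", "blip", or "kimi" is routed to the matching
# class; everything else falls back to AutoModelForSeq2SeqLM.
assert get_model_class("google/gemma-3-4b-it").__name__ == "Gemma3ForConditionalGeneration"
assert get_model_class("Salesforce/blip-image-captioning-base").__name__ == "BlipForConditionalGeneration"
assert get_model_class("some-org/unknown-checkpoint").__name__ == "AutoModelForSeq2SeqLM"

# Equivalent command-line usage (checkpoint id is a placeholder):
#   python train.py --model Salesforce/blip-image-captioning-base
#   python predict.py --model Salesforce/blip-image-captioning-base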