From 6d4266256f24748e8940276d7d962496f88a2e0e Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 10:09:27 -0400 Subject: [PATCH 01/22] initial changes [untested] --- configs/config.yaml | 18 ++++++ configs/lora_config.yaml | 12 ++++ predict.py | 6 +- train.py | 93 +++++++++++++++++++----------- utils/__init__.py | 0 utils/config.py | 106 ++++++++++++++++++++++++++++++++++ utils/create_dataset.py | 58 +++++++++++++++++++ utils/utilities.py | 121 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 377 insertions(+), 37 deletions(-) create mode 100644 configs/config.yaml create mode 100644 configs/lora_config.yaml create mode 100644 utils/__init__.py create mode 100644 utils/config.py create mode 100644 utils/create_dataset.py create mode 100644 utils/utilities.py diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..3279b6c --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,18 @@ +dataset_id: "ariG23498/license-detection-paligemma" +model_id: "google/gemma-3-1b-pt" +checkpoint_id: "sergiopaniego/gemma-3-4b-pt-object-detection-aug" + +device: "cuda" +dtype: "bfloat16" + +batch_size: 16 +learning_rate: 2e-5 +epochs: 2 + +finetune_method: "lora" # FFT | lora | qlora +use_unsloth: false + +mm_tunable_parts: + - multi_modal_projector + # - vision_tower + # - language_model \ No newline at end of file diff --git a/configs/lora_config.yaml b/configs/lora_config.yaml new file mode 100644 index 0000000..abc74fb --- /dev/null +++ b/configs/lora_config.yaml @@ -0,0 +1,12 @@ +r: 32 +alpha: 64 +dropout: 0.05 +target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - up_proj + - down_proj + - gate_proj +max_seq_length: 2048 # Unsloth will RoPE-scale \ No newline at end of file diff --git a/predict.py b/predict.py index 4d49652..c8f1303 100644 --- a/predict.py +++ b/predict.py @@ -5,8 +5,8 @@ from torch.utils.data import DataLoader from transformers import AutoProcessor, Gemma3ForConditionalGeneration -from config import Configuration -from utils import test_collate_function, visualize_bounding_boxes +from utils.config import Configuration +from utils.utilities import test_collate_function, visualize_bounding_boxes os.makedirs("outputs", exist_ok=True) @@ -23,7 +23,7 @@ def get_dataloader(processor): if __name__ == "__main__": - cfg = Configuration() + cfg = Configuration.from_args() processor = AutoProcessor.from_pretrained(cfg.checkpoint_id) model = Gemma3ForConditionalGeneration.from_pretrained( cfg.checkpoint_id, diff --git a/train.py b/train.py index 8aab73d..db51689 100644 --- a/train.py +++ b/train.py @@ -3,12 +3,13 @@ from functools import partial import torch +from torch.amp import autocast, GradScaler from datasets import load_dataset from torch.utils.data import DataLoader from transformers import AutoProcessor, Gemma3ForConditionalGeneration -from config import Configuration -from utils import train_collate_function +from utils.config import Configuration +from utils.utilities import train_collate_function import argparse import albumentations as A @@ -42,38 +43,57 @@ def get_dataloader(processor, args, dtype): return train_dataloader -def train_model(model, optimizer, cfg, train_dataloader): +def train_model(model, optimizer, cfg:Configuration, train_dataloader): logger.info("Start training") global_step = 0 + + + use_fp16 = False + if cfg.dtype in ["float16", "bfloat16"]: + scaler = GradScaler() + use_fp16 = True + for epoch in range(cfg.epochs): for idx, batch in enumerate(train_dataloader): - outputs = 
model(**batch.to(model.device)) - loss = outputs.loss + optimizer.zero_grad() # zero grad before every batch + + if use_fp16: + with autocast(device_type=cfg.device): + outputs = model(**batch.to(model.device)) + loss = outputs.loss + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + outputs = model(**batch.to(model.device)) + loss = outputs.loss + loss.backward() + optimizer.step() + if idx % 100 == 0: logger.info(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}") wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step) - - loss.backward() - optimizer.step() - optimizer.zero_grad() global_step += 1 + return model if __name__ == "__main__": + # 1. Parse CLI + YAMLs into config cfg = Configuration.from_args() - # Get values dynamicaly from user - parser = argparse.ArgumentParser(description="Training for PaLiGemma") - parser.add_argument('--model_id', type=str, required=True, default=cfg.model_id, help='Enter Huggingface Model ID') - parser.add_argument('--dataset_id', type=str, required=True ,default=cfg.dataset_id, help='Enter Huggingface Dataset ID') - parser.add_argument('--batch_size', type=int, default=cfg.batch_size, help='Enter Batch Size') - parser.add_argument('--lr', type=float, default=cfg.learning_rate, help='Enter Learning Rate') - parser.add_argument('--checkpoint_id', type=str, required=True, default=cfg.checkpoint_id, help='Enter Huggingface Repo ID to push model') + # # Get values dynamicaly from user + # parser = argparse.ArgumentParser(description="Training for PaLiGemma") + # parser.add_argument('--model_id', type=str, required=True, default=cfg.model_id, help='Enter Huggingface Model ID') + # parser.add_argument('--dataset_id', type=str, required=True ,default=cfg.dataset_id, help='Enter Huggingface Dataset ID') + # parser.add_argument('--batch_size', type=int, default=cfg.batch_size, help='Enter Batch Size') + # parser.add_argument('--lr', type=float, default=cfg.learning_rate, help='Enter Learning Rate') + # parser.add_argument('--checkpoint_id', type=str, required=True, default=cfg.checkpoint_id, help='Enter Huggingface Repo ID to push model') - args = parser.parse_args() - processor = AutoProcessor.from_pretrained(args.model_id) - train_dataloader = get_dataloader(processor=processor, args=args, dtype=cfg.dtype) + # args = parser.parse_args() + processor = AutoProcessor.from_pretrained(cfg.model_id) + train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) logger.info("Getting model & turning only attention parameters to trainable") model = Gemma3ForConditionalGeneration.from_pretrained( @@ -82,30 +102,35 @@ def train_model(model, optimizer, cfg, train_dataloader): device_map="cpu", attn_implementation="eager", ) - for name, param in model.named_parameters(): - if "attn" in name: - param.requires_grad = True - else: - param.requires_grad = False + + # No need to finetune entire model (especially base language model) just use cfg.mm_tunable_parts + # for name, param in model.named_parameters(): + # if "attn" in name: + # param.requires_grad = True + # else: + # param.requires_grad = False + for layer_name, param in model.named_parameters(): + param.requires_grad = any(tune_part in layer_name for tune_part in cfg.mm_tunable_parts) + model.train() model.to(cfg.device) # Credits to Sayak Paul for this beautiful expression params_to_train = list(filter(lambda x: x.requires_grad, model.parameters())) - optimizer = torch.optim.AdamW(params_to_train, lr=args.lr) + optimizer = 
torch.optim.AdamW(params_to_train, lr=cfg.learning_rate) - wandb.init( - project=cfg.project_name, - name=cfg.run_name if hasattr(cfg, "run_name") else None, - config=vars(cfg), - ) + # wandb.init( + # project=cfg.project_name, + # name=cfg.run_name if hasattr(cfg, "run_name") else None, + # config=vars(cfg), + # ) train_model(model, optimizer, cfg, train_dataloader) - # Push the checkpoint to hub - model.push_to_hub(cfg.checkpoint_id) - processor.push_to_hub(cfg.checkpoint_id) + # # Push the checkpoint to hub + # model.push_to_hub(cfg.checkpoint_id) + # processor.push_to_hub(cfg.checkpoint_id) - wandb.finish() + # wandb.finish() logger.info("Train finished") diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..5833719 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,106 @@ +import argparse +import torch +from dataclasses import dataclass, field +from typing import List +from omegaconf import OmegaConf +import os + + +def str2bool(v): + if isinstance(v, bool): return v + if v.lower() in ('yes', 'true', 't', '1'): return True + if v.lower() in ('no', 'false', 'f', '0'): return False + raise argparse.ArgumentTypeError("Boolean value expected.") + + +@dataclass +class LoRAConfig: + r: int = 32 + alpha: int = 64 + dropout: float = 0.05 + target_modules: List[str] = field(default_factory=lambda: [ + "q_proj", "k_proj", "v_proj", "o_proj", + "up_proj", "down_proj", "gate_proj" + ]) + max_seq_length: int = 2048 + + +@dataclass +class Configuration: + dataset_id: str = "ariG23498/license-detection-paligemma" + model_id: str = "google/gemma-3-4b-pt" + checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-aug" + device: str = "cuda" if torch.cuda.is_available() else "cpu" + dtype: torch.dtype = torch.bfloat16 + + batch_size: int = 16 + learning_rate: float = 2e-5 + epochs: int = 2 + + finetune_method: str = "FFT" # FFT | lora | qlora + use_unsloth: bool = False + mm_tunable_parts: List[str] = field(default_factory=lambda: ["multi_modal_projector"]) # vision_tower,language_model + lora: LoRAConfig = field(default_factory=LoRAConfig) + + @classmethod + def load(cls, main_cfg_path="configs/config.yaml", lora_cfg_path="configs/lora_config.yaml"): + base_cfg = OmegaConf.load(main_cfg_path) + lora_cfg = OmegaConf.load(lora_cfg_path) + base_cfg.lora = lora_cfg + return OmegaConf.to_container(base_cfg, resolve=True) + + @classmethod + def from_args(cls): + cfg_dict = cls.load() # Load YAML as dict + parser = argparse.ArgumentParser() + + # Top-level args + parser.add_argument("--dataset_id", type=str, default=cfg_dict["dataset_id"]) + parser.add_argument("--model_id", type=str, default=cfg_dict["model_id"]) + parser.add_argument("--checkpoint_id", type=str, default=cfg_dict["checkpoint_id"]) + parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default=cfg_dict["device"]) + parser.add_argument("--dtype", type=str, choices=["float32", "float16", "bfloat16"], default="bfloat16") + parser.add_argument("--batch_size", type=int, default=cfg_dict["batch_size"]) + parser.add_argument("--learning_rate", type=float, default=cfg_dict["learning_rate"]) + parser.add_argument("--epochs", type=int, default=cfg_dict["epochs"]) + parser.add_argument("--finetune_method", type=str, choices=["FFT", "lora", "qlora"], default=cfg_dict["finetune_method"]) + parser.add_argument("--use_unsloth", type=str2bool, default=cfg_dict["use_unsloth"]) + 
parser.add_argument("--mm_tunable_parts", type=str, default=",".join(cfg_dict["mm_tunable_parts"]))
+
+        # LoRA nested config overrides
+        parser.add_argument("--lora.r", type=int, default=cfg_dict["lora"]["r"])
+        parser.add_argument("--lora.alpha", type=int, default=cfg_dict["lora"]["alpha"])
+        parser.add_argument("--lora.dropout", type=float, default=cfg_dict["lora"]["dropout"])
+        parser.add_argument("--lora.target_modules", type=str, default=",".join(cfg_dict["lora"]["target_modules"]))
+        parser.add_argument("--lora.max_seq_length", type=int, default=cfg_dict["lora"]["max_seq_length"])
+
+        args = parser.parse_args()
+
+        dtype_map = {
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+        }
+
+        lora_config = LoRAConfig(
+            r=args.__dict__["lora.r"],
+            alpha=args.__dict__["lora.alpha"],
+            dropout=args.__dict__["lora.dropout"],
+            target_modules=[x.strip() for x in args.__dict__["lora.target_modules"].split(',')],
+            max_seq_length=args.__dict__["lora.max_seq_length"],
+        )
+
+        return cls(
+            dataset_id=args.dataset_id,
+            model_id=args.model_id,
+            checkpoint_id=args.checkpoint_id,
+            device=args.device,
+            dtype=dtype_map[args.dtype],
+            batch_size=args.batch_size,
+            learning_rate=args.learning_rate,
+            epochs=args.epochs,
+            finetune_method=args.finetune_method,
+            use_unsloth=args.use_unsloth,
+            mm_tunable_parts=[x.strip() for x in args.mm_tunable_parts.split(',')],
+            lora=lora_config,
+        )
\ No newline at end of file
diff --git a/utils/create_dataset.py b/utils/create_dataset.py
new file mode 100644
index 0000000..825a5b3
--- /dev/null
+++ b/utils/create_dataset.py
@@ -0,0 +1,58 @@
+from datasets import load_dataset
+import argparse
+
+def coco_to_xyxy(coco_bbox):
+    x, y, width, height = coco_bbox
+    x1, y1 = x, y
+    x2, y2 = x + width, y + height
+    return [x1, y1, x2, y2]
+
+
+def convert_to_detection_string(bboxs, image_width, image_height):
+    def format_location(value, max_value):
+        return f"<loc{int(round(value * 1024 / max_value)):04d}>"
+
+    detection_strings = []
+    for bbox in bboxs:
+        x1, y1, x2, y2 = coco_to_xyxy(bbox)
+        name = "plate"
+        locs = [
+            format_location(y1, image_height),
+            format_location(x1, image_width),
+            format_location(y2, image_height),
+            format_location(x2, image_width),
+        ]
+        detection_string = "".join(locs) + f" {name}"
+        detection_strings.append(detection_string)
+
+    return " ; ".join(detection_strings)
+
+
+def format_objects(example):
+    height = example["height"]
+    width = example["width"]
+    bboxs = example["objects"]["bbox"]
+    formatted_objects = convert_to_detection_string(bboxs, width, height)
+    return {"label_for_paligemma": formatted_objects}
+
+
+if __name__ == "__main__":
+    from utils.config import Configuration # To avoid circular import error
+
+    # Support for generic script for dataset
+    cfg = Configuration()
+    parser = argparse.ArgumentParser(description='Process dataset for PaLiGemma')
+    parser.add_argument('--dataset', type=str, required=True, default=cfg.dataset_id, help='Hugging Face dataset ID')
+    parser.add_argument('--output_repo', type=str, required=True, help='Output repository ID for Hugging Face Hub')
+    args = parser.parse_args()
+
+    # load the dataset
+    print(f"[INFO] Loading {args.dataset} from hub...")
+    dataset = load_dataset(args.dataset, args.config) if args.config else load_dataset(args.dataset)
+
+    for split in dataset.keys():
+        print(f"[INFO] Processing split: {split}")
+        dataset[split] = dataset[split].map(format_objects)
+
+    # push to hub
+    dataset.push_to_hub(args.output_repo)
diff --git a/utils/utilities.py b/utils/utilities.py
new file mode 100644
index 0000000..7fe3fde
--- /dev/null
+++ b/utils/utilities.py
@@ -0,0 +1,121 @@
+import re
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import ImageDraw
+
+from utils.create_dataset import format_objects
+
+def parse_paligemma_label(label, width, height):
+    # Extract location codes
+    loc_pattern = r"<loc(\d{4})>"
+    locations = [int(loc) for loc in re.findall(loc_pattern, label)]
+
+    # Extract category (everything after the last location code)
+    category = label.split(">")[-1].strip()
+
+    # Convert normalized locations back to original image coordinates
+    # Order in PaliGemma format is: y1, x1, y2, x2
+    y1_norm, x1_norm, y2_norm, x2_norm = locations
+
+    # Convert normalized coordinates to actual coordinates
+    x1 = (x1_norm / 1024) * width
+    y1 = (y1_norm / 1024) * height
+    x2 = (x2_norm / 1024) * width
+    y2 = (y2_norm / 1024) * height
+
+    return category, [x1, y1, x2, y2]
+
+
+def visualize_bounding_boxes(image, label, width, height, name):
+    # Create a copy of the image to draw on
+    draw_image = image.copy()
+    draw = ImageDraw.Draw(draw_image)
+
+    # Parse the label
+    category, bbox = parse_paligemma_label(label, width, height)
+
+    # Draw the bounding box
+    draw.rectangle(bbox, outline="red", width=2)
+
+    # Add category label
+    draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red")
+
+    # Show the image
+    plt.figure(figsize=(10, 6))
+    plt.imshow(draw_image)
+    plt.axis("off")
+    plt.title(f"Bounding Box: {category}")
+    plt.tight_layout()
+    plt.savefig(name)
+    plt.show()
+    plt.close()
+
+
+def train_collate_function(batch_of_samples, processor, dtype, transform=None):
+    images = []
+    prompts = []
+    for sample in batch_of_samples:
+        if transform:
+            transformed = transform(image=np.array(sample["image"]), bboxes=sample["objects"]["bbox"], category_ids=sample["objects"]["category"])
+            sample["image"] = transformed["image"]
+            sample["objects"]["bbox"] = transformed["bboxes"]
+            sample["objects"]["category"] = transformed["category_ids"]
+            sample["height"] = sample["image"].shape[0]
+            sample["width"] = sample["image"].shape[1]
+            sample['label_for_paligemma'] = format_objects(sample)['label_for_paligemma']
+        images.append([sample["image"]])
+        prompts.append(
+            f"{processor.tokenizer.boi_token} detect \n\n{sample['label_for_paligemma']} {processor.tokenizer.eos_token}"
+        )
+
+    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
+
+    # The labels are the input_ids, and we mask the padding tokens in the loss computation
+    labels = batch["input_ids"].clone() # Clone input IDs for labels
+
+    # List from https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
+    # Mask image tokens
+    image_token_id = [
+        processor.tokenizer.convert_tokens_to_ids(
+            processor.tokenizer.special_tokens_map["boi_token"]
+        )
+    ]
+    # Mask tokens for not being used in the loss computation
+    labels[labels == processor.tokenizer.pad_token_id] = -100
+    labels[labels == image_token_id] = -100
+    labels[labels == 262144] = -100
+
+    batch["labels"] = labels
+
+    batch["pixel_values"] = batch["pixel_values"].to(
+        dtype
+    ) # to check with the implementation
+    return batch
+
+
+def test_collate_function(batch_of_samples, processor, dtype):
+    images = []
+    prompts = []
+    for sample in batch_of_samples:
+        images.append([sample["image"]])
+        prompts.append(f"{processor.tokenizer.boi_token} detect \n\n")
+
+    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
+    batch["pixel_values"] = batch["pixel_values"].to(
+        dtype
+    ) # to check with
the implementation + return batch, images + +def str2bool(v): + """ + Helper function to parse boolean values from cli arguments + """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') \ No newline at end of file From b0b79fda84e790b3f71de555c65db2a2bab75ff0 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 10:11:39 -0400 Subject: [PATCH 02/22] fixed device map to auto --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index db51689..76b865d 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): for epoch in range(cfg.epochs): for idx, batch in enumerate(train_dataloader): optimizer.zero_grad() # zero grad before every batch - + if use_fp16: with autocast(device_type=cfg.device): outputs = model(**batch.to(model.device)) @@ -99,7 +99,7 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): model = Gemma3ForConditionalGeneration.from_pretrained( cfg.model_id, torch_dtype=cfg.dtype, - device_map="cpu", + device_map="auto", attn_implementation="eager", ) From e239b6c90599f7f1551af80184f7f74b7f5adc91 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 10:13:36 -0400 Subject: [PATCH 03/22] moved root files to modules --- config.py | 19 -------- create_dataset.py | 57 ------------------------ utils.py | 108 ---------------------------------------------- 3 files changed, 184 deletions(-) delete mode 100644 config.py delete mode 100644 create_dataset.py delete mode 100644 utils.py diff --git a/config.py b/config.py deleted file mode 100644 index 333d864..0000000 --- a/config.py +++ /dev/null @@ -1,19 +0,0 @@ -from dataclasses import dataclass - -import torch - - -@dataclass -class Configuration: - dataset_id: str = "ariG23498/license-detection-paligemma" - - model_id: str = "google/gemma-3-4b-pt" - checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-aug" - - device: str = "cuda" if torch.cuda.is_available() else "cpu" - dtype: torch.dtype = torch.bfloat16 - - batch_size: int = 8 - learning_rate: float = 2e-05 - epochs = 2 - diff --git a/create_dataset.py b/create_dataset.py deleted file mode 100644 index 6dce684..0000000 --- a/create_dataset.py +++ /dev/null @@ -1,57 +0,0 @@ -from datasets import load_dataset -import argparse -from config import Configuration - -def coco_to_xyxy(coco_bbox): - x, y, width, height = coco_bbox - x1, y1 = x, y - x2, y2 = x + width, y + height - return [x1, y1, x2, y2] - - -def convert_to_detection_string(bboxs, image_width, image_height): - def format_location(value, max_value): - return f"" - - detection_strings = [] - for bbox in bboxs: - x1, y1, x2, y2 = coco_to_xyxy(bbox) - name = "plate" - locs = [ - format_location(y1, image_height), - format_location(x1, image_width), - format_location(y2, image_height), - format_location(x2, image_width), - ] - detection_string = "".join(locs) + f" {name}" - detection_strings.append(detection_string) - - return " ; ".join(detection_strings) - - -def format_objects(example): - height = example["height"] - width = example["width"] - bboxs = example["objects"]["bbox"] - formatted_objects = convert_to_detection_string(bboxs, width, height) - return {"label_for_paligemma": formatted_objects} - - -if __name__ == "__main__": - # Support for generic script for 
dataset - cfg = Configuration() - parser = argparse.ArgumentParser(description='Process dataset for PaLiGemma') - parser.add_argument('--dataset', type=str, required=True, default=cfg.dataset_id, help='Hugging Face dataset ID') - parser.add_argument('--output_repo', type=str, required=True, help='Output repository ID for Hugging Face Hub') - args = parser.parse_args() - - # load the dataset - print(f"[INFO] Loading {args.dataset} from hub...") - dataset = load_dataset(args.dataset, args.config) if args.config else load_dataset(args.dataset) - - for split in dataset.keys(): - print(f"[INFO] Processing split: {split}") - dataset[split] = dataset[split].map(format_objects) - - # push to hub - dataset.push_to_hub(args.output_repo) diff --git a/utils.py b/utils.py deleted file mode 100644 index 2cee681..0000000 --- a/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import re - -import matplotlib.pyplot as plt -import numpy as np -from PIL import ImageDraw - -from create_dataset import format_objects - -def parse_paligemma_label(label, width, height): - # Extract location codes - loc_pattern = r"" - locations = [int(loc) for loc in re.findall(loc_pattern, label)] - - # Extract category (everything after the last location code) - category = label.split(">")[-1].strip() - - # Convert normalized locations back to original image coordinates - # Order in PaliGemma format is: y1, x1, y2, x2 - y1_norm, x1_norm, y2_norm, x2_norm = locations - - # Convert normalized coordinates to actual coordinates - x1 = (x1_norm / 1024) * width - y1 = (y1_norm / 1024) * height - x2 = (x2_norm / 1024) * width - y2 = (y2_norm / 1024) * height - - return category, [x1, y1, x2, y2] - - -def visualize_bounding_boxes(image, label, width, height, name): - # Create a copy of the image to draw on - draw_image = image.copy() - draw = ImageDraw.Draw(draw_image) - - # Parse the label - category, bbox = parse_paligemma_label(label, width, height) - - # Draw the bounding box - draw.rectangle(bbox, outline="red", width=2) - - # Add category label - draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red") - - # Show the image - plt.figure(figsize=(10, 6)) - plt.imshow(draw_image) - plt.axis("off") - plt.title(f"Bounding Box: {category}") - plt.tight_layout() - plt.savefig(name) - plt.show() - plt.close() - - -def train_collate_function(batch_of_samples, processor, dtype, transform=None): - images = [] - prompts = [] - for sample in batch_of_samples: - if transform: - transformed = transform(image=np.array(sample["image"]), bboxes=sample["objects"]["bbox"], category_ids=sample["objects"]["category"]) - sample["image"] = transformed["image"] - sample["objects"]["bbox"] = transformed["bboxes"] - sample["objects"]["category"] = transformed["category_ids"] - sample["height"] = sample["image"].shape[0] - sample["width"] = sample["image"].shape[1] - sample['label_for_paligemma'] = format_objects(sample)['label_for_paligemma'] - images.append([sample["image"]]) - prompts.append( - f"{processor.tokenizer.boi_token} detect \n\n{sample['label_for_paligemma']} {processor.tokenizer.eos_token}" - ) - - batch = processor(images=images, text=prompts, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() # Clone input IDs for labels - - # List from https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora - # Mask image tokens - image_token_id = [ - processor.tokenizer.convert_tokens_to_ids( - 
processor.tokenizer.special_tokens_map["boi_token"] - ) - ] - # Mask tokens for not being used in the loss computation - labels[labels == processor.tokenizer.pad_token_id] = -100 - labels[labels == image_token_id] = -100 - labels[labels == 262144] = -100 - - batch["labels"] = labels - - batch["pixel_values"] = batch["pixel_values"].to( - dtype - ) # to check with the implementation - return batch - - -def test_collate_function(batch_of_samples, processor, dtype): - images = [] - prompts = [] - for sample in batch_of_samples: - images.append([sample["image"]]) - prompts.append(f"{processor.tokenizer.boi_token} detect \n\n") - - batch = processor(images=images, text=prompts, return_tensors="pt", padding=True) - batch["pixel_values"] = batch["pixel_values"].to( - dtype - ) # to check with the implementation - return batch, images From 3311eb6643bacd0d1551eee8218821c032ca0ce1 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 11:15:39 -0400 Subject: [PATCH 04/22] updated lora training code --- configs/config.yaml | 8 +++--- train.py | 70 ++++++++++++++++++++++++++------------------- utils/config.py | 2 +- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index 3279b6c..b79ab8b 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,15 +1,15 @@ dataset_id: "ariG23498/license-detection-paligemma" -model_id: "google/gemma-3-1b-pt" +model_id: "google/gemma-3-4b-pt" checkpoint_id: "sergiopaniego/gemma-3-4b-pt-object-detection-aug" device: "cuda" -dtype: "bfloat16" +dtype: "float16" -batch_size: 16 +batch_size: 1 learning_rate: 2e-5 epochs: 2 -finetune_method: "lora" # FFT | lora | qlora +finetune_method: "qlora" # FFT | lora | qlora use_unsloth: false mm_tunable_parts: diff --git a/train.py b/train.py index 76b865d..bbea2ff 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,7 @@ from utils.config import Configuration from utils.utilities import train_collate_function -import argparse +from peft import get_peft_config, get_peft_model, LoraConfig import albumentations as A logging.basicConfig( @@ -44,15 +44,16 @@ def get_dataloader(processor, args, dtype): def train_model(model, optimizer, cfg:Configuration, train_dataloader): - logger.info("Start training") global_step = 0 - - use_fp16 = False - if cfg.dtype in ["float16", "bfloat16"]: + if cfg.dtype in [torch.float16, torch.bfloat16]: scaler = GradScaler() use_fp16 = True + logger.info("using fp16 to scale loss") + else: + logger.info(f"Found dtype: {cfg.dtype}") + logger.info("Start training") for epoch in range(cfg.epochs): for idx, batch in enumerate(train_dataloader): optimizer.zero_grad() # zero grad before every batch @@ -79,40 +80,49 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): return model -if __name__ == "__main__": - # 1. 
Parse CLI + YAMLs into config - cfg = Configuration.from_args() - - # # Get values dynamicaly from user - # parser = argparse.ArgumentParser(description="Training for PaLiGemma") - # parser.add_argument('--model_id', type=str, required=True, default=cfg.model_id, help='Enter Huggingface Model ID') - # parser.add_argument('--dataset_id', type=str, required=True ,default=cfg.dataset_id, help='Enter Huggingface Dataset ID') - # parser.add_argument('--batch_size', type=int, default=cfg.batch_size, help='Enter Batch Size') - # parser.add_argument('--lr', type=float, default=cfg.learning_rate, help='Enter Learning Rate') - # parser.add_argument('--checkpoint_id', type=str, required=True, default=cfg.checkpoint_id, help='Enter Huggingface Repo ID to push model') - - # args = parser.parse_args() - processor = AutoProcessor.from_pretrained(cfg.model_id) - train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) +def load_model(cfg:Configuration): - logger.info("Getting model & turning only attention parameters to trainable") model = Gemma3ForConditionalGeneration.from_pretrained( cfg.model_id, torch_dtype=cfg.dtype, - device_map="auto", + device_map="cpu", attn_implementation="eager", ) - # No need to finetune entire model (especially base language model) just use cfg.mm_tunable_parts - # for name, param in model.named_parameters(): - # if "attn" in name: - # param.requires_grad = True - # else: - # param.requires_grad = False - for layer_name, param in model.named_parameters(): - param.requires_grad = any(tune_part in layer_name for tune_part in cfg.mm_tunable_parts) + if cfg.finetune_method in {"lora", "qlora"}: + lcfg = cfg.lora + lora_cfg = LoraConfig( + r=lcfg.r, + lora_alpha=lcfg.alpha, + target_modules=lcfg.target_modules, + lora_dropout=lcfg.dropout, + bias="none", + ) + + model = get_peft_model(model, lora_cfg) + model.print_trainable_parameters() + + elif cfg.finetune_method == "FFT": + # Only unfreeze requested model parts (e.g. multi_modal_projector) + for n, p in model.named_parameters(): + p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) + print(f"{n} will be finetuned") + else: + raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") + + return model + +if __name__ == "__main__": + # 1. 
Parse CLI + YAMLs into config + cfg = Configuration.from_args() + processor = AutoProcessor.from_pretrained(cfg.model_id) + train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) + + logger.info("Getting model & turning only attention parameters to trainable") + model = load_model(cfg) + model.train() model.to(cfg.device) diff --git a/utils/config.py b/utils/config.py index 5833719..1665a9b 100644 --- a/utils/config.py +++ b/utils/config.py @@ -59,7 +59,7 @@ def from_args(cls): parser.add_argument("--model_id", type=str, default=cfg_dict["model_id"]) parser.add_argument("--checkpoint_id", type=str, default=cfg_dict["checkpoint_id"]) parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default=cfg_dict["device"]) - parser.add_argument("--dtype", type=str, choices=["float32", "float16", "bfloat16"], default="bfloat16") + parser.add_argument("--dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float16") parser.add_argument("--batch_size", type=int, default=cfg_dict["batch_size"]) parser.add_argument("--learning_rate", type=float, default=cfg_dict["learning_rate"]) parser.add_argument("--epochs", type=int, default=cfg_dict["epochs"]) From 835695946cbcd3d59c5f4cfe88ed4b8c217a7c3f Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 11:47:24 -0400 Subject: [PATCH 05/22] updated quantization option for qlora --- train.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index bbea2ff..d2e2098 100644 --- a/train.py +++ b/train.py @@ -7,6 +7,7 @@ from datasets import load_dataset from torch.utils.data import DataLoader from transformers import AutoProcessor, Gemma3ForConditionalGeneration +from transformers import BitsAndBytesConfig from utils.config import Configuration from utils.utilities import train_collate_function @@ -82,14 +83,33 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): def load_model(cfg:Configuration): + bnb_config = None + quant_args = {} + + if cfg.finetune_method == "qlora": + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=cfg.dtype, + ) + quant_args.update({ + "quantization_config": bnb_config, + "device_map": "auto", + }) + logger.info("Loaded model in 4-bit with bitsandbytes") + model = Gemma3ForConditionalGeneration.from_pretrained( cfg.model_id, torch_dtype=cfg.dtype, - device_map="cpu", attn_implementation="eager", + **quant_args, ) if cfg.finetune_method in {"lora", "qlora"}: + for n, p in model.named_parameters(): + p.requires_grad = False + lcfg = cfg.lora lora_cfg = LoraConfig( r=lcfg.r, @@ -97,6 +117,7 @@ def load_model(cfg:Configuration): target_modules=lcfg.target_modules, lora_dropout=lcfg.dropout, bias="none", + task_type="CAUSAL_LM", ) model = get_peft_model(model, lora_cfg) From c8eea4e6b313cc65888e9db0afdc119ba543d1cb Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 11:58:08 -0400 Subject: [PATCH 06/22] updated quantization option for qlora --- train.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/train.py b/train.py index d2e2098..0e2f2ca 100644 --- a/train.py +++ b/train.py @@ -83,26 +83,22 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): def load_model(cfg:Configuration): - bnb_config = None - quant_args = {} - - if cfg.finetune_method == "qlora": - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - 
bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=cfg.dtype, - ) - quant_args.update({ - "quantization_config": bnb_config, - "device_map": "auto", - }) - logger.info("Loaded model in 4-bit with bitsandbytes") + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=cfg.dtype, + ) + quant_args = { + "quantization_config": bnb_config, + "device_map": "auto", + } model = Gemma3ForConditionalGeneration.from_pretrained( cfg.model_id, torch_dtype=cfg.dtype, attn_implementation="eager", + device_map="auto", **quant_args, ) @@ -122,6 +118,7 @@ def load_model(cfg:Configuration): model = get_peft_model(model, lora_cfg) model.print_trainable_parameters() + torch.cuda.empty_cache() elif cfg.finetune_method == "FFT": # Only unfreeze requested model parts (e.g. multi_modal_projector) From d8cebfb541276acac81c9681beba6be53165cab2 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 12:26:42 -0400 Subject: [PATCH 07/22] added unsloth changes --- train.py | 154 +++++++++++++++++++++++++++++++-------------- utils/utilities.py | 48 ++++++++++++++ 2 files changed, 154 insertions(+), 48 deletions(-) diff --git a/train.py b/train.py index 0e2f2ca..33401a1 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,9 @@ +# Optional – Unsloth is only imported if the flag is set at runtime +try: + from unsloth import FastModel +except ImportError: + FastModel = None # will be checked at runtime + import logging import wandb from functools import partial @@ -10,7 +16,7 @@ from transformers import BitsAndBytesConfig from utils.config import Configuration -from utils.utilities import train_collate_function +from utils.utilities import train_collate_function, train_collate_function_unsloth from peft import get_peft_config, get_peft_model, LoraConfig import albumentations as A @@ -26,8 +32,26 @@ A.ColorJitter(p=0.2), ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) +def get_dataloader_unsloth(tokenizer, args, dtype): + logger.info("Fetching the dataset") + train_dataset = load_dataset(args.dataset_id, split="train") # or cfg.dataset_id + train_collate_fn = partial( + train_collate_function_unsloth, + tokenizer=tokenizer, # <- Use the Unsloth tokenizer + dtype=dtype, + transform=augmentations + ) -def get_dataloader(processor, args, dtype): + logger.info("Building data loader") + train_dataloader = DataLoader( + train_dataset, + batch_size=args.batch_size, + collate_fn=train_collate_fn, + shuffle=True, + ) + return train_dataloader + +def get_dataloader(processor, args, dtype, tokenizer=None): logger.info("Fetching the dataset") train_dataset = load_dataset(cfg.dataset_id, split="train") train_collate_fn = partial( @@ -83,63 +107,97 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): def load_model(cfg:Configuration): - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=cfg.dtype, - ) - quant_args = { - "quantization_config": bnb_config, - "device_map": "auto", - } - - model = Gemma3ForConditionalGeneration.from_pretrained( - cfg.model_id, - torch_dtype=cfg.dtype, - attn_implementation="eager", - device_map="auto", - **quant_args, - ) + lcfg = cfg.lora + tokenizer = None - if cfg.finetune_method in {"lora", "qlora"}: - for n, p in model.named_parameters(): - p.requires_grad = False - - lcfg = cfg.lora - lora_cfg = LoraConfig( - 
r=lcfg.r, - lora_alpha=lcfg.alpha, - target_modules=lcfg.target_modules, - lora_dropout=lcfg.dropout, - bias="none", - task_type="CAUSAL_LM", + if cfg.use_unsloth and FastModel is not None: + + model, tokenizer = FastModel.from_pretrained( + model_name = "unsloth/gemma-3-4b-it", + max_seq_length = 2048, # Choose any for long context! + load_in_4bit = True, # 4 bit quantization to reduce memory + load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory + full_finetuning = False, # [NEW!] We have full finetuning now! + # token = "hf_...", # use one if using gated models ) - - model = get_peft_model(model, lora_cfg) - model.print_trainable_parameters() - torch.cuda.empty_cache() - - elif cfg.finetune_method == "FFT": - # Only unfreeze requested model parts (e.g. multi_modal_projector) - for n, p in model.named_parameters(): - p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) - print(f"{n} will be finetuned") + + if cfg.finetune_method in {"lora", "qlora"}: + model = FastModel.get_peft_model( + model, + finetune_vision_layers = True if "vision" in lcfg.target_modules else False, # Turn off for just text! + finetune_language_layers = True, # Should leave on! + finetune_attention_modules = True, # Attention good for GRPO + finetune_mlp_modules = True, # SHould leave on always! + + r=lcfg.r, # Larger = higher accuracy, but might overfit + lora_alpha=lcfg.alpha, # Recommended alpha == r at least + lora_dropout=lcfg.dropout, + bias = "none", + random_state = 3407, + ) + + model.print_trainable_parameters() + + else: - raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=cfg.dtype, + ) + quant_args = { + "quantization_config": bnb_config, + "device_map": "auto", + } + + model = Gemma3ForConditionalGeneration.from_pretrained( + cfg.model_id, + torch_dtype=cfg.dtype, + attn_implementation="eager", + **quant_args, + ) + + if cfg.finetune_method in {"lora", "qlora"}: + for n, p in model.named_parameters(): + p.requires_grad = False + + lora_cfg = LoraConfig( + r=lcfg.r, + lora_alpha=lcfg.alpha, + target_modules=lcfg.target_modules, + lora_dropout=lcfg.dropout, + bias="none", + ) + + model = get_peft_model(model, lora_cfg) + model.print_trainable_parameters() + torch.cuda.empty_cache() + + elif cfg.finetune_method == "FFT": + # Only unfreeze requested model parts (e.g. multi_modal_projector) + for n, p in model.named_parameters(): + p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) + print(f"{n} will be finetuned") + else: + raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") - return model + return model, tokenizer if __name__ == "__main__": # 1. 
Parse CLI + YAMLs into config cfg = Configuration.from_args() - processor = AutoProcessor.from_pretrained(cfg.model_id) - train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) - logger.info("Getting model & turning only attention parameters to trainable") - model = load_model(cfg) + model, tokenizer = load_model(cfg) + + if cfg.use_unsloth: + train_dataloader = get_dataloader_unsloth(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype) + else: + processor = AutoProcessor.from_pretrained(cfg.model_id) + train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) model.train() model.to(cfg.device) diff --git a/utils/utilities.py b/utils/utilities.py index 7fe3fde..1e8ecf4 100644 --- a/utils/utilities.py +++ b/utils/utilities.py @@ -51,6 +51,54 @@ def visualize_bounding_boxes(image, label, width, height, name): plt.show() plt.close() +def train_collate_function_unsloth(batch_of_samples, tokenizer, dtype, transform=None): + """ + unsloth + """ + images = [] + prompts = [] + for sample in batch_of_samples: + if transform: + transformed = transform( + image=np.array(sample["image"]), + bboxes=sample["objects"]["bbox"], + category_ids=sample["objects"]["category"] + ) + sample["image"] = transformed["image"] + sample["objects"]["bbox"] = transformed["bboxes"] + sample["objects"]["category"] = transformed["category_ids"] + sample["height"] = sample["image"].shape[0] + sample["width"] = sample["image"].shape[1] + sample['label_for_paligemma'] = format_objects(sample)['label_for_paligemma'] + images.append([sample["image"]]) + prompts.append( + f"{tokenizer.boi_token} detect \n\n{sample['label_for_paligemma']} {tokenizer.eos_token}" + ) + + # Use tokenizer directly (Unsloth tokenizer supports vision inputs for Gemma3) + batch = tokenizer( + images=images, + text=prompts, + return_tensors="pt", + padding=True + ) + + labels = batch["input_ids"].clone() + + # Mask out padding, image tokens, and other special tokens from loss + image_token_id = [ + tokenizer.convert_tokens_to_ids(tokenizer.boi_token) + ] + labels[labels == tokenizer.pad_token_id] = -100 + for tok_id in image_token_id: + labels[labels == tok_id] = -100 + labels[labels == 262144] = -100 # If this ID is used for your "unused" special token + + batch["labels"] = labels + if "pixel_values" in batch: + batch["pixel_values"] = batch["pixel_values"].to(dtype) + + return batch def train_collate_function(batch_of_samples, processor, dtype, transform=None): images = [] From 5c59f8ec4624d4dfe532f6184c2e48c832e4cc02 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 16:04:56 -0400 Subject: [PATCH 08/22] fixed tokenizer issue, made img size to small --- configs/config.yaml | 2 +- train.py | 13 +++++++------ utils/utilities.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index b79ab8b..81f3b64 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -3,7 +3,7 @@ model_id: "google/gemma-3-4b-pt" checkpoint_id: "sergiopaniego/gemma-3-4b-pt-object-detection-aug" device: "cuda" -dtype: "float16" +dtype: "bfloat16" batch_size: 1 learning_rate: 2e-5 diff --git a/train.py b/train.py index 33401a1..dd70eca 100644 --- a/train.py +++ b/train.py @@ -1,8 +1,9 @@ # Optional – Unsloth is only imported if the flag is set at runtime -try: - from unsloth import FastModel -except ImportError: - FastModel = None # will be checked at runtime +# try: +# from unsloth import FastModel +# except ImportError: +# FastModel = None # will be checked 
at runtime +FastModel = None import logging import wandb @@ -27,7 +28,7 @@ augmentations = A.Compose([ - A.Resize(height=896, width=896), + A.Resize(height=224, width=224), A.HorizontalFlip(p=0.5), A.ColorJitter(p=0.2), ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) @@ -149,7 +150,7 @@ def load_model(cfg:Configuration): ) quant_args = { "quantization_config": bnb_config, - "device_map": "auto", + "device_map": "cpu", } model = Gemma3ForConditionalGeneration.from_pretrained( diff --git a/utils/utilities.py b/utils/utilities.py index 1e8ecf4..5381a91 100644 --- a/utils/utilities.py +++ b/utils/utilities.py @@ -87,7 +87,7 @@ def train_collate_function_unsloth(batch_of_samples, tokenizer, dtype, transform # Mask out padding, image tokens, and other special tokens from loss image_token_id = [ - tokenizer.convert_tokens_to_ids(tokenizer.boi_token) + tokenizer.tokenizer.convert_tokens_to_ids(tokenizer.boi_token) ] labels[labels == tokenizer.pad_token_id] = -100 for tok_id in image_token_id: From d0bdf5ec26bc88b0ba1b05011443126a1064b628 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 16:16:10 -0400 Subject: [PATCH 09/22] unsloth uncommented --- train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index dd70eca..061920f 100644 --- a/train.py +++ b/train.py @@ -1,9 +1,9 @@ # Optional – Unsloth is only imported if the flag is set at runtime -# try: -# from unsloth import FastModel -# except ImportError: -# FastModel = None # will be checked at runtime -FastModel = None +try: + from unsloth import FastModel +except ImportError: + FastModel = None # will be checked at runtime +# FastModel = None import logging import wandb @@ -150,7 +150,7 @@ def load_model(cfg:Configuration): ) quant_args = { "quantization_config": bnb_config, - "device_map": "cpu", + "device_map": "auto", } model = Gemma3ForConditionalGeneration.from_pretrained( From 1514fdbfae514e8d6e5429de00da96377e5f7c7a Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 16:22:19 -0400 Subject: [PATCH 10/22] set quant type to fp4 --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 061920f..da87ae1 100644 --- a/train.py +++ b/train.py @@ -145,7 +145,7 @@ def load_model(cfg:Configuration): bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", + bnb_4bit_quant_type="fp4", bnb_4bit_compute_dtype=cfg.dtype, ) quant_args = { From 5fd67e411bc5922fc310d7da99de7cfdf3b6b9b4 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 16:23:47 -0400 Subject: [PATCH 11/22] cfg print --- train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train.py b/train.py index da87ae1..16f3ba7 100644 --- a/train.py +++ b/train.py @@ -79,6 +79,9 @@ def train_model(model, optimizer, cfg:Configuration, train_dataloader): else: logger.info(f"Found dtype: {cfg.dtype}") + + logger.info(f"config : {vars(cfg)}") + logger.info("Start training") for epoch in range(cfg.epochs): for idx, batch in enumerate(train_dataloader): From edb273ef127bffdaf8228bc549fabe0d3b70a1aa Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 16:42:44 -0400 Subject: [PATCH 12/22] wandb disable --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 16f3ba7..68199af 100644 --- a/train.py +++ b/train.py @@ -103,7 +103,7 @@ def train_model(model, optimizer, cfg:Configuration, 
train_dataloader): if idx % 100 == 0: logger.info(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}") - wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step) + # wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step) global_step += 1 return model From f9abfb09d60246768add893ccebe3f1f21260783 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 22:48:57 +0000 Subject: [PATCH 13/22] workig code for unsloth --- configs/config.yaml | 5 +- train.py | 125 ++++++++++++++++++++++++++------------------ utils/config.py | 5 ++ utils/utilities.py | 29 +++++++++- 4 files changed, 108 insertions(+), 56 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index 81f3b64..dd07d7d 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,6 +1,6 @@ dataset_id: "ariG23498/license-detection-paligemma" model_id: "google/gemma-3-4b-pt" -checkpoint_id: "sergiopaniego/gemma-3-4b-pt-object-detection-aug" +checkpoint_id: "ajaymin28/Gemma3_ObjeDet" device: "cuda" dtype: "bfloat16" @@ -15,4 +15,5 @@ use_unsloth: false mm_tunable_parts: - multi_modal_projector # - vision_tower - # - language_model \ No newline at end of file + # - language_model +project_name: "Gemma3_LoRA" \ No newline at end of file diff --git a/train.py b/train.py index 68199af..e4574ff 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,4 @@ -# Optional – Unsloth is only imported if the flag is set at runtime +# Optional – comment this out if you are not planinng to use unsloth try: from unsloth import FastModel except ImportError: @@ -17,7 +17,7 @@ from transformers import BitsAndBytesConfig from utils.config import Configuration -from utils.utilities import train_collate_function, train_collate_function_unsloth +from utils.utilities import train_collate_function, train_collate_function_unsloth, save_best_model, push_to_hub from peft import get_peft_config, get_peft_model, LoraConfig import albumentations as A @@ -28,14 +28,14 @@ augmentations = A.Compose([ - A.Resize(height=224, width=224), - A.HorizontalFlip(p=0.5), + A.Resize(height=896, width=896), + # A.HorizontalFlip(p=0.5), # does this handle flipping box coordinates? 
A.ColorJitter(p=0.2), ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) -def get_dataloader_unsloth(tokenizer, args, dtype): +def get_dataloader_unsloth(tokenizer, args, dtype, split="train"): logger.info("Fetching the dataset") - train_dataset = load_dataset(args.dataset_id, split="train") # or cfg.dataset_id + train_dataset = load_dataset(args.dataset_id, split=split) # or cfg.dataset_id train_collate_fn = partial( train_collate_function_unsloth, tokenizer=tokenizer, # <- Use the Unsloth tokenizer @@ -52,9 +52,9 @@ def get_dataloader_unsloth(tokenizer, args, dtype): ) return train_dataloader -def get_dataloader(processor, args, dtype, tokenizer=None): +def get_dataloader(processor, args, dtype, tokenizer=None, split="train"): logger.info("Fetching the dataset") - train_dataset = load_dataset(cfg.dataset_id, split="train") + train_dataset = load_dataset(cfg.dataset_id, split=split) train_collate_fn = partial( train_collate_function, processor=processor, dtype=dtype, transform=augmentations ) @@ -65,47 +65,66 @@ def get_dataloader(processor, args, dtype, tokenizer=None): batch_size=args.batch_size, collate_fn=train_collate_fn, shuffle=True, + pin_memory=True, ) return train_dataloader - -def train_model(model, optimizer, cfg:Configuration, train_dataloader): - global_step = 0 - use_fp16 = False - if cfg.dtype in [torch.float16, torch.bfloat16]: - scaler = GradScaler() - use_fp16 = True - logger.info("using fp16 to scale loss") +def step(model, batch, device, use_fp16, optimizer=None, scaler=None): + data = batch.to(device) + if use_fp16: + with autocast(device_type=device): + loss = model(**data).loss else: - logger.info(f"Found dtype: {cfg.dtype}") - + loss = model(**data).loss + if optimizer: + optimizer.zero_grad() + if use_fp16: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + return loss.item() + +def validate_all(model, val_loader, device, use_fp16): + model.eval() + with torch.no_grad(): + n_batches = 2 + losses = [] + for i, batch in enumerate(val_loader): + if i >= n_batches: + break + losses.append(step(model, batch, device, use_fp16)) + + # losses = [step(model, batch, device, use_fp16) for batch in val_loader] + model.train() + return sum(losses) / len(losses) if len(losses)> 0 else 0 - logger.info(f"config : {vars(cfg)}") +def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every=5, push_hub=False): + use_fp16 = cfg.dtype in [torch.float16, torch.bfloat16] + scaler = GradScaler() if use_fp16 else None + global_step, best_val_loss = 0, float("inf") - logger.info("Start training") for epoch in range(cfg.epochs): - for idx, batch in enumerate(train_dataloader): - optimizer.zero_grad() # zero grad before every batch - - if use_fp16: - with autocast(device_type=cfg.device): - outputs = model(**batch.to(model.device)) - loss = outputs.loss - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - outputs = model(**batch.to(model.device)) - loss = outputs.loss - loss.backward() - optimizer.step() - - if idx % 100 == 0: - logger.info(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}") - # wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step) + for idx, batch in enumerate(train_loader): + loss = step(model, batch, cfg.device, use_fp16, optimizer, scaler) + if global_step % 1 == 0: + logger.info(f"Epoch:{epoch} Step:{global_step} Loss:{loss:.4f}") + wandb.log({"train/loss": loss, "epoch": 
epoch}, step=global_step) + if val_loader and global_step % val_every == 0: + val_loss = validate_all(model, val_loader, cfg.device, use_fp16) + logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") + wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) global_step += 1 + if val_loss < best_val_loss: + best_val_loss = val_loss + save_best_model(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}, logger) + if push_hub: + logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") + push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) + return model @@ -128,7 +147,7 @@ def load_model(cfg:Configuration): if cfg.finetune_method in {"lora", "qlora"}: model = FastModel.get_peft_model( model, - finetune_vision_layers = True if "vision" in lcfg.target_modules else False, # Turn off for just text! + finetune_vision_layers = True, # Turn off for just text! finetune_language_layers = True, # Should leave on! finetune_attention_modules = True, # Attention good for GRPO finetune_mlp_modules = True, # SHould leave on always! @@ -187,6 +206,10 @@ def load_model(cfg:Configuration): else: raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") + for n, p in model.named_parameters(): + if p.requires_grad: + print(f"{n} will be finetuned") + return model, tokenizer @@ -199,9 +222,11 @@ def load_model(cfg:Configuration): if cfg.use_unsloth: train_dataloader = get_dataloader_unsloth(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype) + validation_dataloader = get_dataloader_unsloth(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype, split="validation") else: processor = AutoProcessor.from_pretrained(cfg.model_id) train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) + validation_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype, split="validation") model.train() model.to(cfg.device) @@ -210,17 +235,13 @@ def load_model(cfg:Configuration): params_to_train = list(filter(lambda x: x.requires_grad, model.parameters())) optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate) - # wandb.init( - # project=cfg.project_name, - # name=cfg.run_name if hasattr(cfg, "run_name") else None, - # config=vars(cfg), - # ) - - train_model(model, optimizer, cfg, train_dataloader) + wandb.init( + project=cfg.project_name, + name=cfg.run_name if hasattr(cfg, "run_name") else None, + config=vars(cfg), + ) - # # Push the checkpoint to hub - # model.push_to_hub(cfg.checkpoint_id) - # processor.push_to_hub(cfg.checkpoint_id) + train_model(model, optimizer, cfg, train_dataloader, validation_dataloader, push_hub=True) - # wandb.finish() + wandb.finish() logger.info("Train finished") diff --git a/utils/config.py b/utils/config.py index 1665a9b..b9b5c45 100644 --- a/utils/config.py +++ b/utils/config.py @@ -42,6 +42,8 @@ class Configuration: mm_tunable_parts: List[str] = field(default_factory=lambda: ["multi_modal_projector"]) # vision_tower,language_model lora: LoRAConfig = field(default_factory=LoRAConfig) + project_name: str = "Gemma3_LoRA" + @classmethod def load(cls, main_cfg_path="configs/config.yaml", lora_cfg_path="configs/lora_config.yaml"): base_cfg = OmegaConf.load(main_cfg_path) @@ -74,6 +76,8 @@ def from_args(cls): parser.add_argument("--lora.target_modules", type=str, default=",".join(cfg_dict["lora"]["target_modules"])) parser.add_argument("--lora.max_seq_length", type=int, default=cfg_dict["lora"]["max_seq_length"]) + parser.add_argument("--wandb_project", type=str, default=cfg_dict["project_name"]) + args 
= parser.parse_args() dtype_map = { @@ -103,4 +107,5 @@ def from_args(cls): use_unsloth=args.use_unsloth, mm_tunable_parts=[x.strip() for x in args.mm_tunable_parts.split(',')], lora=lora_config, + project_name=args.wandb_project ) \ No newline at end of file diff --git a/utils/utilities.py b/utils/utilities.py index 5381a91..9abd34f 100644 --- a/utils/utilities.py +++ b/utils/utilities.py @@ -1,9 +1,10 @@ +import os import re import argparse import matplotlib.pyplot as plt import numpy as np from PIL import ImageDraw - +import torch from utils.create_dataset import format_objects def parse_paligemma_label(label, width, height): @@ -166,4 +167,28 @@ def str2bool(v): elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') \ No newline at end of file + raise argparse.ArgumentTypeError('Boolean value expected.') + +def push_to_hub(model, cfg, tokenizer=None, is_lora=False): + """ + Push model to huggingface + """ + push_kwargs = {} + if tokenizer is not None: + push_kwargs['tokenizer'] = tokenizer + model.push_to_hub(cfg.checkpoint_id, **push_kwargs) + if tokenizer is not None: + tokenizer.push_to_hub(cfg.checkpoint_id) + +def save_best_model(model, cfg, tokenizer=None, is_lora=False, logger=None): + """Save LoRA adapter or full model based on config.""" + save_path = f"checkpoints/{cfg.checkpoint_id}_best" + os.makedirs(save_path, exist_ok=True) + if is_lora: + if logger: logger.info(f"Saving LoRA adapter to {save_path}") + model.save_pretrained(save_path) + if tokenizer is not None: + tokenizer.save_pretrained(save_path) + else: + if logger: logger.info(f"Saving full model weights to {save_path}.pt") + torch.save(model.state_dict(), f"{save_path}.pt") \ No newline at end of file From 3ed763bd4565be646df5989f64e9337a65047177 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 23:16:48 +0000 Subject: [PATCH 14/22] final checkpoint for usloth training [uncleaned] --- configs/config.yaml | 11 +++++----- train.py | 51 ++++++++++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index dd07d7d..266e038 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -5,15 +5,16 @@ checkpoint_id: "ajaymin28/Gemma3_ObjeDet" device: "cuda" dtype: "bfloat16" -batch_size: 1 +batch_size: 16 learning_rate: 2e-5 -epochs: 2 +epochs: 1 finetune_method: "qlora" # FFT | lora | qlora use_unsloth: false -mm_tunable_parts: + +mm_tunable_parts: # Only for the FFT - multi_modal_projector - # - vision_tower - # - language_model + - vision_tower + - language_model project_name: "Gemma3_LoRA" \ No newline at end of file diff --git a/train.py b/train.py index e4574ff..3d91446 100644 --- a/train.py +++ b/train.py @@ -87,17 +87,19 @@ def step(model, batch, device, use_fp16, optimizer=None, scaler=None): optimizer.step() return loss.item() -def validate_all(model, val_loader, device, use_fp16): +def validate_all(model, val_loader, device, use_fp16, val_bathes=None): model.eval() with torch.no_grad(): - n_batches = 2 - losses = [] - for i, batch in enumerate(val_loader): - if i >= n_batches: - break - losses.append(step(model, batch, device, use_fp16)) - - # losses = [step(model, batch, device, use_fp16) for batch in val_loader] + if val_bathes: + ## TODO: This logic is Temp and should be removed in final clean up + n_batches = 10 + losses = [] + for i, batch in enumerate(val_loader): + if i >= n_batches: + break + losses.append(step(model, batch, 
device, use_fp16)) + else: + losses = [step(model, batch, device, use_fp16) for batch in val_loader] model.train() return sum(losses) / len(losses) if len(losses)> 0 else 0 @@ -121,9 +123,6 @@ def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every= if val_loss < best_val_loss: best_val_loss = val_loss save_best_model(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}, logger) - if push_hub: - logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") - push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) return model @@ -163,17 +162,16 @@ def load_model(cfg:Configuration): else: - - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="fp4", - bnb_4bit_compute_dtype=cfg.dtype, - ) - quant_args = { - "quantization_config": bnb_config, - "device_map": "auto", - } + quant_args = {} + # Enable quantization only for QLoRA or if specifically requested for LoRA + if cfg.finetune_method in {"lora", "qlora"}: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="fp4", + bnb_4bit_compute_dtype=cfg.dtype, + ) + quant_args = {"quantization_config": bnb_config, "device_map": "auto"} model = Gemma3ForConditionalGeneration.from_pretrained( cfg.model_id, @@ -241,7 +239,12 @@ def load_model(cfg:Configuration): config=vars(cfg), ) - train_model(model, optimizer, cfg, train_dataloader, validation_dataloader, push_hub=True) + train_model(model, optimizer, cfg, train_dataloader, validation_dataloader,val_every=10, push_hub=True) + + # TODO add flag to config (code tested and its working) + # if push_hub: + # logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") + # push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) wandb.finish() logger.info("Train finished") From 56b25d09589e57fdecd977ed42f1a1d3e73419a2 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 23:18:24 +0000 Subject: [PATCH 15/22] final checkpoint for usloth training [uncleaned] --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 3d91446..ccae0df 100644 --- a/train.py +++ b/train.py @@ -87,12 +87,12 @@ def step(model, batch, device, use_fp16, optimizer=None, scaler=None): optimizer.step() return loss.item() -def validate_all(model, val_loader, device, use_fp16, val_bathes=None): +def validate_all(model, val_loader, device, use_fp16, val_batches=5): model.eval() with torch.no_grad(): - if val_bathes: + if val_batches: ## TODO: This logic is Temp and should be removed in final clean up - n_batches = 10 + n_batches = val_batches losses = [] for i, batch in enumerate(val_loader): if i >= n_batches: From dcfd2d78809849f08c63c9be0a4eaa8f9ce9d3a9 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 23:28:12 +0000 Subject: [PATCH 16/22] lad best model back and push to hub --- train.py | 21 +++++++++++++-------- utils/utilities.py | 28 +++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index ccae0df..e1566ab 100644 --- a/train.py +++ b/train.py @@ -17,7 +17,7 @@ from transformers import BitsAndBytesConfig from utils.config import Configuration -from utils.utilities import train_collate_function, train_collate_function_unsloth, save_best_model, push_to_hub +from utils.utilities import train_collate_function, train_collate_function_unsloth, save_best_model, push_to_hub, load_saved_model from peft import 
get_peft_config, get_peft_model, LoraConfig import albumentations as A @@ -103,7 +103,7 @@ def validate_all(model, val_loader, device, use_fp16, val_batches=5): model.train() return sum(losses) / len(losses) if len(losses)> 0 else 0 -def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every=5, push_hub=False): +def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every=5, max_step=10): use_fp16 = cfg.dtype in [torch.float16, torch.bfloat16] scaler = GradScaler() if use_fp16 else None global_step, best_val_loss = 0, float("inf") @@ -124,6 +124,9 @@ def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every= best_val_loss = val_loss save_best_model(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}, logger) + if global_step>max_step: + break + return model @@ -164,7 +167,7 @@ def load_model(cfg:Configuration): else: quant_args = {} # Enable quantization only for QLoRA or if specifically requested for LoRA - if cfg.finetune_method in {"lora", "qlora"}: + if cfg.finetune_method in {"qlora"}: bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, @@ -181,8 +184,8 @@ def load_model(cfg:Configuration): ) if cfg.finetune_method in {"lora", "qlora"}: - for n, p in model.named_parameters(): - p.requires_grad = False + # for n, p in model.named_parameters(): + # p.requires_grad = False lora_cfg = LoraConfig( r=lcfg.r, @@ -239,12 +242,14 @@ def load_model(cfg:Configuration): config=vars(cfg), ) - train_model(model, optimizer, cfg, train_dataloader, validation_dataloader,val_every=10, push_hub=True) + train_model(model, optimizer, cfg, train_dataloader, validation_dataloader,val_every=5, max_step=10) + # Loading best model back + model, tokenizer = load_saved_model(cfg, is_lora=cfg.finetune_method in {"lora", "qlora"}, device="cuda", logger=logger) + logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") # TODO add flag to config (code tested and its working) # if push_hub: - # logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") - # push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) + push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) wandb.finish() logger.info("Train finished") diff --git a/utils/utilities.py b/utils/utilities.py index 9abd34f..3bfcd78 100644 --- a/utils/utilities.py +++ b/utils/utilities.py @@ -6,6 +6,8 @@ from PIL import ImageDraw import torch from utils.create_dataset import format_objects +from transformers import AutoModel, AutoTokenizer # Change to your model class if needed +from peft import PeftModel, PeftConfig def parse_paligemma_label(label, width, height): # Extract location codes @@ -191,4 +193,28 @@ def save_best_model(model, cfg, tokenizer=None, is_lora=False, logger=None): tokenizer.save_pretrained(save_path) else: if logger: logger.info(f"Saving full model weights to {save_path}.pt") - torch.save(model.state_dict(), f"{save_path}.pt") \ No newline at end of file + torch.save(model.state_dict(), f"{save_path}.pt") + + +def load_saved_model(cfg, is_lora=False, device=None, logger=None): + """ + Load LoRA adapter or full model based on config. 
+ Returns (model, tokenizer) + """ + save_path = f"checkpoints/{cfg.checkpoint_id}_best" + tokenizer = None + + if is_lora: + if logger: logger.info(f"Loading LoRA adapter from {save_path}") + # Load base model first, then LoRA weights + base_model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") + model = PeftModel.from_pretrained(base_model, save_path, device_map=device or "auto") + if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): + tokenizer = AutoTokenizer.from_pretrained(save_path) + else: + if logger: logger.info(f"Loading full model weights from {save_path}.pt") + model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") + model.load_state_dict(torch.load(f"{save_path}.pt", map_location=device or "cpu")) + if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): + tokenizer = AutoTokenizer.from_pretrained(save_path) + return model, tokenizer \ No newline at end of file From c40dcbf3be05914e541229acd316bf91f529759d Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Mon, 16 Jun 2025 23:44:34 +0000 Subject: [PATCH 17/22] tested code train/val, save models --- train.py | 7 ++++--- utils/utilities.py | 38 ++++++++++++++++++++++++++------------ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/train.py b/train.py index e1566ab..de4ec21 100644 --- a/train.py +++ b/train.py @@ -118,15 +118,16 @@ def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every= val_loss = validate_all(model, val_loader, cfg.device, use_fp16) logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) - global_step += 1 if val_loss < best_val_loss: best_val_loss = val_loss save_best_model(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}, logger) - if global_step>max_step: + if global_step>max_step-1: break + global_step += 1 + return model @@ -242,7 +243,7 @@ def load_model(cfg:Configuration): config=vars(cfg), ) - train_model(model, optimizer, cfg, train_dataloader, validation_dataloader,val_every=5, max_step=10) + train_model(model, optimizer, cfg, train_dataloader, validation_dataloader,val_every=5, max_step=2) # Loading best model back model, tokenizer = load_saved_model(cfg, is_lora=cfg.finetune_method in {"lora", "qlora"}, device="cuda", logger=logger) diff --git a/utils/utilities.py b/utils/utilities.py index 3bfcd78..214a251 100644 --- a/utils/utilities.py +++ b/utils/utilities.py @@ -182,6 +182,8 @@ def push_to_hub(model, cfg, tokenizer=None, is_lora=False): if tokenizer is not None: tokenizer.push_to_hub(cfg.checkpoint_id) + + def save_best_model(model, cfg, tokenizer=None, is_lora=False, logger=None): """Save LoRA adapter or full model based on config.""" save_path = f"checkpoints/{cfg.checkpoint_id}_best" @@ -204,17 +206,29 @@ def load_saved_model(cfg, is_lora=False, device=None, logger=None): save_path = f"checkpoints/{cfg.checkpoint_id}_best" tokenizer = None - if is_lora: + if cfg.use_unsloth: + if logger: logger.info(f"Loading LoRA adapter from {save_path}") - # Load base model first, then LoRA weights - base_model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") - model = PeftModel.from_pretrained(base_model, save_path, device_map=device or "auto") - if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): - tokenizer = AutoTokenizer.from_pretrained(save_path) + + from unsloth import FastModel + model, tokenizer = FastModel.from_pretrained( + model_name = save_path, # YOUR MODEL 
YOU USED FOR TRAINING + load_in_4bit = True, # Set to False for 16bit LoRA + ) + return model, tokenizer + else: - if logger: logger.info(f"Loading full model weights from {save_path}.pt") - model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") - model.load_state_dict(torch.load(f"{save_path}.pt", map_location=device or "cpu")) - if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): - tokenizer = AutoTokenizer.from_pretrained(save_path) - return model, tokenizer \ No newline at end of file + if is_lora: + if logger: logger.info(f"Loading LoRA adapter from {save_path}") + # Load base model first, then LoRA weights + base_model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") + model = PeftModel.from_pretrained(base_model, save_path, device_map=device or "auto") + if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): + tokenizer = AutoTokenizer.from_pretrained(save_path) + else: + if logger: logger.info(f"Loading full model weights from {save_path}.pt") + model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") + model.load_state_dict(torch.load(f"{save_path}.pt", map_location=device or "cpu")) + if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): + tokenizer = AutoTokenizer.from_pretrained(save_path) + return model, tokenizer \ No newline at end of file From fd31e7c2951614253cc24a28506bea9aeab65ac0 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Tue, 17 Jun 2025 10:03:08 -0400 Subject: [PATCH 18/22] cleaned code --- configs/lora_config.yaml | 4 +- train.py | 126 ++++++++++++++++++++++----------------- utils/config.py | 29 ++++++--- 3 files changed, 93 insertions(+), 66 deletions(-) diff --git a/configs/lora_config.yaml b/configs/lora_config.yaml index abc74fb..62e00c0 100644 --- a/configs/lora_config.yaml +++ b/configs/lora_config.yaml @@ -9,4 +9,6 @@ target_modules: - up_proj - down_proj - gate_proj -max_seq_length: 2048 # Unsloth will RoPE-scale \ No newline at end of file +max_seq_length: 2048 # Unsloth will RoPE-scale +load_in_4bit: true +load_in_8bit: false \ No newline at end of file diff --git a/train.py b/train.py index de4ec21..2672bd4 100644 --- a/train.py +++ b/train.py @@ -3,12 +3,12 @@ from unsloth import FastModel except ImportError: FastModel = None # will be checked at runtime -# FastModel = None +# FastModel = None # uncomment this line when commenting above lines import logging import wandb from functools import partial - +import os import torch from torch.amp import autocast, GradScaler from datasets import load_dataset @@ -29,35 +29,38 @@ augmentations = A.Compose([ A.Resize(height=896, width=896), - # A.HorizontalFlip(p=0.5), # does this handle flipping box coordinates? 
+ A.HorizontalFlip(p=0.5), A.ColorJitter(p=0.2), ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) -def get_dataloader_unsloth(tokenizer, args, dtype, split="train"): - logger.info("Fetching the dataset") - train_dataset = load_dataset(args.dataset_id, split=split) # or cfg.dataset_id - train_collate_fn = partial( - train_collate_function_unsloth, - tokenizer=tokenizer, # <- Use the Unsloth tokenizer - dtype=dtype, - transform=augmentations - ) - - logger.info("Building data loader") - train_dataloader = DataLoader( - train_dataset, - batch_size=args.batch_size, - collate_fn=train_collate_fn, - shuffle=True, - ) - return train_dataloader - -def get_dataloader(processor, args, dtype, tokenizer=None, split="train"): +# def get_dataloader_unsloth(tokenizer, args, dtype, split="train"): +# logger.info("Fetching the dataset") +# train_dataset = load_dataset(args.dataset_id, split=split) # or cfg.dataset_id +# train_collate_fn = partial( +# train_collate_function_unsloth, +# tokenizer=tokenizer, # <- Use the Unsloth tokenizer instead of processor +# dtype=dtype, +# transform=augmentations +# ) + +# logger.info("Building data loader") +# train_dataloader = DataLoader( +# train_dataset, +# batch_size=args.batch_size, +# collate_fn=train_collate_fn, +# shuffle=True, +# ) +# return train_dataloader + +def get_dataloader(processor, args, dtype, split="train", is_unsloth=False): logger.info("Fetching the dataset") train_dataset = load_dataset(cfg.dataset_id, split=split) - train_collate_fn = partial( - train_collate_function, processor=processor, dtype=dtype, transform=augmentations - ) + + if is_unsloth: + # <- Use the Unsloth tokenizer instead of processor + train_collate_fn = partial(train_collate_function_unsloth,tokenizer=tokenizer,dtype=dtype,transform=augmentations) + else: + train_collate_fn = partial(train_collate_function, processor=processor, dtype=dtype, transform=augmentations) logger.info("Building data loader") train_dataloader = DataLoader( @@ -70,6 +73,9 @@ def get_dataloader(processor, args, dtype, tokenizer=None, split="train"): return train_dataloader def step(model, batch, device, use_fp16, optimizer=None, scaler=None): + """ + Single batch process + """ data = batch.to(device) if use_fp16: with autocast(device_type=device): @@ -87,7 +93,9 @@ def step(model, batch, device, use_fp16, optimizer=None, scaler=None): optimizer.step() return loss.item() -def validate_all(model, val_loader, device, use_fp16, val_batches=5): +def validate_all(model, val_loader, cfg, use_fp16,val_batches=5): + + model.eval() with torch.no_grad(): if val_batches: @@ -97,25 +105,31 @@ def validate_all(model, val_loader, device, use_fp16, val_batches=5): for i, batch in enumerate(val_loader): if i >= n_batches: break - losses.append(step(model, batch, device, use_fp16)) + losses.append(step(model, batch, cfg.device, use_fp16)) else: - losses = [step(model, batch, device, use_fp16) for batch in val_loader] + losses = [step(model, batch, cfg.device, use_fp16) for batch in val_loader] model.train() return sum(losses) / len(losses) if len(losses)> 0 else 0 -def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every=5, max_step=10): +def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=None): use_fp16 = cfg.dtype in [torch.float16, torch.bfloat16] scaler = GradScaler() if use_fp16 else None global_step, best_val_loss = 0, float("inf") + if cfg.use_unsloth and FastModel is not None: + FastModel.for_training(model) # 
Enable for inference! + else: + model.train() + model.to(cfg.device) + for epoch in range(cfg.epochs): for idx, batch in enumerate(train_loader): loss = step(model, batch, cfg.device, use_fp16, optimizer, scaler) if global_step % 1 == 0: logger.info(f"Epoch:{epoch} Step:{global_step} Loss:{loss:.4f}") wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) - if val_loader and global_step % val_every == 0: - val_loss = validate_all(model, val_loader, cfg.device, use_fp16) + if val_loader and global_step % cfg.validate_steps_freq == 0: + val_loss = validate_all(model, val_loader, cfg, use_fp16) logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) @@ -123,9 +137,9 @@ def train_model(model, optimizer, cfg, train_loader, val_loader=None, val_every= best_val_loss = val_loss save_best_model(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}, logger) - if global_step>max_step-1: + ## Model seem to converge before even first epoch finishes for LoRA. set max_step_to_train<=0 to disable this. + if global_step>cfg.max_step_to_train-1 and cfg.max_step_to_train>0: break - global_step += 1 return model @@ -138,13 +152,14 @@ def load_model(cfg:Configuration): if cfg.use_unsloth and FastModel is not None: + # TODO: For LoRA and QLoRa change unsloth config accordigly, generally load_in_4bit, load_in_8bit will be False or LoRA model, tokenizer = FastModel.from_pretrained( model_name = "unsloth/gemma-3-4b-it", max_seq_length = 2048, # Choose any for long context! load_in_4bit = True, # 4 bit quantization to reduce memory load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory full_finetuning = False, # [NEW!] We have full finetuning now! - # token = "hf_...", # use one if using gated models + # token = os.environ["HF_TOKEN"] # TODO: Handle this ) if cfg.finetune_method in {"lora", "qlora"}: @@ -167,7 +182,7 @@ def load_model(cfg:Configuration): else: quant_args = {} - # Enable quantization only for QLoRA or if specifically requested for LoRA + # Enable quantization only for QLoRA if cfg.finetune_method in {"qlora"}: bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -185,8 +200,6 @@ def load_model(cfg:Configuration): ) if cfg.finetune_method in {"lora", "qlora"}: - # for n, p in model.named_parameters(): - # p.requires_grad = False lora_cfg = LoraConfig( r=lcfg.r, @@ -198,17 +211,17 @@ def load_model(cfg:Configuration): model = get_peft_model(model, lora_cfg) model.print_trainable_parameters() - torch.cuda.empty_cache() + torch.cuda.empty_cache() # TODO: Do I need this? Just want to make sure I have mem cleaned up before training starts. elif cfg.finetune_method == "FFT": - # Only unfreeze requested model parts (e.g. multi_modal_projector) - for n, p in model.named_parameters(): - p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) - print(f"{n} will be finetuned") + # handled below before printing params + pass else: raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") for n, p in model.named_parameters(): + if cfg.finetune_method == "FFT": + p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) if p.requires_grad: print(f"{n} will be finetuned") @@ -217,40 +230,41 @@ def load_model(cfg:Configuration): if __name__ == "__main__": # 1. Parse CLI + YAMLs into config - cfg = Configuration.from_args() + cfg = Configuration.from_args() # config.yaml is overriden by CLI arguments - logger.info("Getting model & turning only attention parameters to trainable") + # 2. 
Load model + logger.info(f"Getting model for {cfg.finetune_method}") + # loads model based on config. Unsloth, lora, qlora, FFT model, tokenizer = load_model(cfg) + # 3. Get Data if cfg.use_unsloth: - train_dataloader = get_dataloader_unsloth(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype) - validation_dataloader = get_dataloader_unsloth(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype, split="validation") + train_dataloader = get_dataloader(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype,split="train",is_unsloth=True) + validation_dataloader = get_dataloader(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype, split="validation",is_unsloth=True) else: processor = AutoProcessor.from_pretrained(cfg.model_id) - train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype) + train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype, split="train") validation_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype, split="validation") - model.train() - model.to(cfg.device) - # Credits to Sayak Paul for this beautiful expression params_to_train = list(filter(lambda x: x.requires_grad, model.parameters())) optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate) + # 5. Enable logging, need to login or set wanddb token in os.env wandb.init( - project=cfg.project_name, + project=cfg.wandb_project_name, name=cfg.run_name if hasattr(cfg, "run_name") else None, config=vars(cfg), ) - train_model(model, optimizer, cfg, train_dataloader, validation_dataloader,val_every=5, max_step=2) + # 5. Actual train and validation, validation_dataloader=None to do just traing. + train_model(model, optimizer, cfg, train_dataloader, validation_dataloader) # Loading best model back model, tokenizer = load_saved_model(cfg, is_lora=cfg.finetune_method in {"lora", "qlora"}, device="cuda", logger=logger) logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") - # TODO add flag to config (code tested and its working) - # if push_hub: - push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) + if cfg.push_model_to_hub: + push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) wandb.finish() logger.info("Train finished") diff --git a/utils/config.py b/utils/config.py index b9b5c45..21cd667 100644 --- a/utils/config.py +++ b/utils/config.py @@ -3,8 +3,6 @@ from dataclasses import dataclass, field from typing import List from omegaconf import OmegaConf -import os - def str2bool(v): if isinstance(v, bool): return v @@ -23,26 +21,29 @@ class LoRAConfig: "up_proj", "down_proj", "gate_proj" ]) max_seq_length: int = 2048 - + #QLoRA + load_in_4bit: bool = True + load_in_8bit: bool = False # more precise bet takes more mem + @dataclass class Configuration: dataset_id: str = "ariG23498/license-detection-paligemma" model_id: str = "google/gemma-3-4b-pt" checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-aug" + push_model_to_hub: bool = False device: str = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.bfloat16 - + validate_steps_freq: int = 500 batch_size: int = 16 learning_rate: float = 2e-5 epochs: int = 2 - + max_step_to_train: int = 5000 # if model converges before training one epoch, set to 0 or -1 to disable finetune_method: str = "FFT" # FFT | lora | qlora use_unsloth: bool = False mm_tunable_parts: List[str] = field(default_factory=lambda: ["multi_modal_projector"]) # vision_tower,language_model lora: LoRAConfig = field(default_factory=LoRAConfig) - - project_name: str = 
"Gemma3_LoRA" + wandb_project_name: str = "Gemma3_LoRA" @classmethod def load(cls, main_cfg_path="configs/config.yaml", lora_cfg_path="configs/lora_config.yaml"): @@ -60,11 +61,14 @@ def from_args(cls): parser.add_argument("--dataset_id", type=str, default=cfg_dict["dataset_id"]) parser.add_argument("--model_id", type=str, default=cfg_dict["model_id"]) parser.add_argument("--checkpoint_id", type=str, default=cfg_dict["checkpoint_id"]) + parser.add_argument("--push_model_to_hub", type=str2bool, default=cfg_dict["push_model_to_hub"]) parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default=cfg_dict["device"]) parser.add_argument("--dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float16") parser.add_argument("--batch_size", type=int, default=cfg_dict["batch_size"]) parser.add_argument("--learning_rate", type=float, default=cfg_dict["learning_rate"]) parser.add_argument("--epochs", type=int, default=cfg_dict["epochs"]) + parser.add_argument("--max_step_to_train", type=int, default=cfg_dict["max_step_to_train"]) + parser.add_argument("--validate_steps_freq", type=int, default=cfg_dict["validate_steps_freq"]) parser.add_argument("--finetune_method", type=str, choices=["FFT", "lora", "qlora"], default=cfg_dict["finetune_method"]) parser.add_argument("--use_unsloth", type=str2bool, default=cfg_dict["use_unsloth"]) parser.add_argument("--mm_tunable_parts", type=str, default=",".join(cfg_dict["mm_tunable_parts"])) @@ -75,8 +79,10 @@ def from_args(cls): parser.add_argument("--lora.dropout", type=float, default=cfg_dict["lora"]["dropout"]) parser.add_argument("--lora.target_modules", type=str, default=",".join(cfg_dict["lora"]["target_modules"])) parser.add_argument("--lora.max_seq_length", type=int, default=cfg_dict["lora"]["max_seq_length"]) + parser.add_argument("--lora.load_in_4bit", type=str2bool,default=cfg_dict["lora"]["load_in_4bit"]) + parser.add_argument("--lora.load_in_8bit", type=str2bool,default=cfg_dict["lora"]["load_in_8bit"]) - parser.add_argument("--wandb_project", type=str, default=cfg_dict["project_name"]) + parser.add_argument("--wandb_project_name", type=str, default=cfg_dict["project_name"]) args = parser.parse_args() @@ -92,6 +98,8 @@ def from_args(cls): dropout=args.__dict__["lora.dropout"], target_modules=[x.strip() for x in args.__dict__["lora.target_modules"].split(',')], max_seq_length=args.__dict__["lora.max_seq_length"], + load_in_4bit=args.__dict__["lora.load_in_4bit"], + load_in_8bit=args.__dict__["lora.load_in_8bit"], ) return cls( @@ -107,5 +115,8 @@ def from_args(cls): use_unsloth=args.use_unsloth, mm_tunable_parts=[x.strip() for x in args.mm_tunable_parts.split(',')], lora=lora_config, - project_name=args.wandb_project + wandb_project_name=args.wandb_project_name, + max_step_to_train=args.max_step_to_train, + push_model_to_hub=args.push_model_to_hub, + validate_steps_freq=args.validate_steps_freq ) \ No newline at end of file From 0edd0c6ea6bfa23497381a6bda25791a074a7f05 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Tue, 17 Jun 2025 12:34:47 -0400 Subject: [PATCH 19/22] added TODOs for cleanup and new features --- train.py | 42 ++++++++++++++++++++++++++---------------- utils/config.py | 9 +++++---- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/train.py b/train.py index 2672bd4..4cff275 100644 --- a/train.py +++ b/train.py @@ -1,14 +1,20 @@ -# Optional – comment this out if you are not planinng to use unsloth -try: - from unsloth import FastModel -except ImportError: - FastModel = None # will be checked 
at runtime -# FastModel = None # uncomment this line when commenting above lines import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +FastModel = None +# Optional – comment below imports if you are not planinng to use unsloth +try: from unsloth import FastModel +except ImportError as e: logger.log(f"Unsloth import error : {e}") +except NotImplementedError as e: logger.log(f"Unsloth NotImplementedError error : {e}") + + import wandb from functools import partial -import os import torch from torch.amp import autocast, GradScaler from datasets import load_dataset @@ -17,15 +23,11 @@ from transformers import BitsAndBytesConfig from utils.config import Configuration -from utils.utilities import train_collate_function, train_collate_function_unsloth, save_best_model, push_to_hub, load_saved_model +from utils.utilities import train_collate_function, train_collate_function_unsloth +from utils.utilities import save_best_model, push_to_hub, load_saved_model from peft import get_peft_config, get_peft_model, LoraConfig import albumentations as A -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - augmentations = A.Compose([ A.Resize(height=896, width=896), @@ -33,6 +35,7 @@ A.ColorJitter(p=0.2), ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) +# TODO: Delete this after testing get_dataloader() with is_unsloth=True flag # def get_dataloader_unsloth(tokenizer, args, dtype, split="train"): # logger.info("Fetching the dataset") # train_dataset = load_dataset(args.dataset_id, split=split) # or cfg.dataset_id @@ -95,8 +98,11 @@ def step(model, batch, device, use_fp16, optimizer=None, scaler=None): def validate_all(model, val_loader, cfg, use_fp16,val_batches=5): + if cfg.use_unsloth and FastModel is not None: + FastModel.for_inference(model) # Enable for inference! + else: + model.eval() - model.eval() with torch.no_grad(): if val_batches: ## TODO: This logic is Temp and should be removed in final clean up @@ -129,7 +135,7 @@ def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=No logger.info(f"Epoch:{epoch} Step:{global_step} Loss:{loss:.4f}") wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) if val_loader and global_step % cfg.validate_steps_freq == 0: - val_loss = validate_all(model, val_loader, cfg, use_fp16) + val_loss = validate_all(model, val_loader, cfg, use_fp16, val_batches=5) # TODO, disable val_batches in final commit/run logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) @@ -260,11 +266,15 @@ def load_model(cfg:Configuration): # 5. Actual train and validation, validation_dataloader=None to do just traing. train_model(model, optimizer, cfg, train_dataloader, validation_dataloader) - # Loading best model back + # 6. Loading best model back model, tokenizer = load_saved_model(cfg, is_lora=cfg.finetune_method in {"lora", "qlora"}, device="cuda", logger=logger) logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") if cfg.push_model_to_hub: push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) + # 7. Test? # TODO + + + # 8. 
Wrap up wandb.finish() logger.info("Train finished") diff --git a/utils/config.py b/utils/config.py index 21cd667..17efd93 100644 --- a/utils/config.py +++ b/utils/config.py @@ -12,7 +12,7 @@ def str2bool(v): @dataclass -class LoRAConfig: +class UserLoRAConfig: r: int = 32 alpha: int = 64 dropout: float = 0.05 @@ -42,13 +42,13 @@ class Configuration: finetune_method: str = "FFT" # FFT | lora | qlora use_unsloth: bool = False mm_tunable_parts: List[str] = field(default_factory=lambda: ["multi_modal_projector"]) # vision_tower,language_model - lora: LoRAConfig = field(default_factory=LoRAConfig) + lora: UserLoRAConfig = field(default_factory=UserLoRAConfig) wandb_project_name: str = "Gemma3_LoRA" @classmethod def load(cls, main_cfg_path="configs/config.yaml", lora_cfg_path="configs/lora_config.yaml"): base_cfg = OmegaConf.load(main_cfg_path) - lora_cfg = OmegaConf.load(lora_cfg_path) + lora_cfg = OmegaConf.load(lora_cfg_path) # TODO: Merge config into one, refer to hydra config. base_cfg.lora = lora_cfg return OmegaConf.to_container(base_cfg, resolve=True) @@ -92,7 +92,7 @@ def from_args(cls): "bfloat16": torch.bfloat16, } - lora_config = LoRAConfig( + lora_config = UserLoRAConfig( r=args.__dict__["lora.r"], alpha=args.__dict__["lora.alpha"], dropout=args.__dict__["lora.dropout"], @@ -102,6 +102,7 @@ def from_args(cls): load_in_8bit=args.__dict__["lora.load_in_8bit"], ) + # TODO handle this long list, probably migrate to hydra conf. return cls( dataset_id=args.dataset_id, model_id=args.model_id, From ef628bb4873c70aab531dbc948f651400bdd974a Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Tue, 17 Jun 2025 21:38:45 +0000 Subject: [PATCH 20/22] working code for unsloth qlora, vanilla qlora on l4 --- configs/config.yaml | 17 ++-- configs/lora_config.yaml | 9 +- configs/qlora_config.yaml | 20 ++++ train.py | 188 ++++++++++++++++++++++++++++++-------- utils/config.py | 16 ++-- 5 files changed, 195 insertions(+), 55 deletions(-) create mode 100644 configs/qlora_config.yaml diff --git a/configs/config.yaml b/configs/config.yaml index 266e038..27d11a5 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,5 +1,5 @@ dataset_id: "ariG23498/license-detection-paligemma" -model_id: "google/gemma-3-4b-pt" +model_id: "unsloth/gemma-3-4b-it" #"google/gemma-3-4b-pt" checkpoint_id: "ajaymin28/Gemma3_ObjeDet" device: "cuda" @@ -8,13 +8,18 @@ dtype: "bfloat16" batch_size: 16 learning_rate: 2e-5 epochs: 1 +max_step_to_train: 100 +validate_steps_freq: 10 finetune_method: "qlora" # FFT | lora | qlora use_unsloth: false -mm_tunable_parts: # Only for the FFT - - multi_modal_projector - - vision_tower - - language_model -project_name: "Gemma3_LoRA" \ No newline at end of file +mm_tunable_parts: + - no_exist_layer # basically not finetuning any base components + # - mlp + # - multi_modal_projector + # - vision_tower + # - language_model +wandb_project_name: "Gemma3_LoRA" +push_model_to_hub: true \ No newline at end of file diff --git a/configs/lora_config.yaml b/configs/lora_config.yaml index 62e00c0..b4943c5 100644 --- a/configs/lora_config.yaml +++ b/configs/lora_config.yaml @@ -1,5 +1,5 @@ r: 32 -alpha: 64 +alpha: 32 dropout: 0.05 target_modules: - q_proj @@ -10,5 +10,8 @@ target_modules: - down_proj - gate_proj max_seq_length: 2048 # Unsloth will RoPE-scale -load_in_4bit: true -load_in_8bit: false \ No newline at end of file + +# LoRA-specific: no quantization +load_in_4bit: false +load_in_8bit: false +quantization_config: null \ No newline at end of file diff --git a/configs/qlora_config.yaml 
b/configs/qlora_config.yaml new file mode 100644 index 0000000..dc4d134 --- /dev/null +++ b/configs/qlora_config.yaml @@ -0,0 +1,20 @@ +r: 32 +alpha: 32 +dropout: 0.05 +target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - up_proj + - down_proj + - gate_proj +max_seq_length: 2048 # Unsloth will RoPE-scale + +# QLoRA-specific: quantization enabled +load_in_4bit: true +load_in_8bit: false +quantization_config: + bnb_4bit_use_double_quant: true + bnb_4bit_quant_type: "nf4" + bnb_4bit_compute_dtype: "bfloat16" diff --git a/train.py b/train.py index 4cff275..08a2b0e 100644 --- a/train.py +++ b/train.py @@ -8,10 +8,10 @@ FastModel = None # Optional – comment below imports if you are not planinng to use unsloth -try: from unsloth import FastModel -except ImportError as e: logger.log(f"Unsloth import error : {e}") -except NotImplementedError as e: logger.log(f"Unsloth NotImplementedError error : {e}") - +# try: from unsloth import FastModel +# except ImportError as e: logger.warning(f"Unsloth import error : {e}") +# except NotImplementedError as e: logger.warning(f"Unsloth NotImplementedError error : {e}") + import wandb from functools import partial @@ -26,6 +26,7 @@ from utils.utilities import train_collate_function, train_collate_function_unsloth from utils.utilities import save_best_model, push_to_hub, load_saved_model from peft import get_peft_config, get_peft_model, LoraConfig +from peft import prepare_model_for_kbit_training import albumentations as A @@ -55,15 +56,15 @@ # ) # return train_dataloader -def get_dataloader(processor, args, dtype, split="train", is_unsloth=False): - logger.info("Fetching the dataset") +def get_dataloader(args:Configuration,processor=None,tokenizer=None, split="train", is_unsloth=False): + logger.info(f"Fetching the dataset: {cfg.dataset_id}:{split}") train_dataset = load_dataset(cfg.dataset_id, split=split) if is_unsloth: # <- Use the Unsloth tokenizer instead of processor - train_collate_fn = partial(train_collate_function_unsloth,tokenizer=tokenizer,dtype=dtype,transform=augmentations) + train_collate_fn = partial(train_collate_function_unsloth,tokenizer=tokenizer,dtype=args.dtype,transform=augmentations) else: - train_collate_fn = partial(train_collate_function, processor=processor, dtype=dtype, transform=augmentations) + train_collate_fn = partial(train_collate_function, processor=processor, dtype=args.dtype, transform=augmentations) logger.info("Building data loader") train_dataloader = DataLoader( @@ -98,13 +99,14 @@ def step(model, batch, device, use_fp16, optimizer=None, scaler=None): def validate_all(model, val_loader, cfg, use_fp16,val_batches=5): - if cfg.use_unsloth and FastModel is not None: - FastModel.for_inference(model) # Enable for inference! - else: - model.eval() + # if cfg.use_unsloth and FastModel is not None: + # FastModel.for_inference(model) # Enable for inference! 
+ # else: + # model.eval() + model.eval() with torch.no_grad(): - if val_batches: + if val_batches>0: ## TODO: This logic is Temp and should be removed in final clean up n_batches = val_batches losses = [] @@ -117,27 +119,121 @@ def validate_all(model, val_loader, cfg, use_fp16,val_batches=5): model.train() return sum(losses) / len(losses) if len(losses)> 0 else 0 +import psutil +import os + +def memory_stats(get_dict=False, print_mem_usage=True, device=None): + stats = { + "cpu": "", + "ram": "", + "cuda_free": "", + "cuda_total": "", + "cuda_allocated": "", + "cuda_reserved": "", + "peak_vram_allocated_mb": "", + } + + cuda_freeMem = 0 + cuda_total = 0 + cuda_allocated = 0 + cuda_reserved = 0 + peak_vram_allocated_bytes = 0 + peak_vram_allocated_mb = 0 + + if torch.cuda.is_available(): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + try: + cuda_freeMem, cuda_total = torch.cuda.mem_get_info() + cuda_total = cuda_total/1024**2 + cuda_freeMem = cuda_freeMem/1024**2 + except: pass + + try: + cuda_allocated = torch.cuda.memory_allocated()/1024**2 + cuda_reserved = torch.cuda.memory_reserved()/1024**2 + except: pass + + try: + peak_vram_allocated_bytes = torch.cuda.max_memory_allocated(device) + peak_vram_allocated_mb = peak_vram_allocated_bytes / (1024 ** 2) + except: pass + + stats["cuda_free"] = cuda_freeMem + stats["cuda_total"] = cuda_total + stats["cuda_allocated"] = round(cuda_allocated,3) + stats["cuda_reserved"] = round(cuda_reserved,3) + stats["peak_vram_allocated_mb"] = round(peak_vram_allocated_mb,3) + + process = psutil.Process(os.getpid()) + ram_mem_perc = process.memory_percent() + cpu_usage = psutil.cpu_percent() + + stats["cpu"] = cpu_usage + stats["ram"] = ram_mem_perc + + if print_mem_usage: + logger.info(f"CPU: {cpu_usage:.2f}% RAM: {ram_mem_perc:.2f}% GPU memory Total: [{cuda_total:.2f}] Available: [{cuda_freeMem:.2f}] Allocated: [{cuda_allocated:.2f}] Reserved: [{cuda_reserved:.2f}] Cuda Peak Mem: {peak_vram_allocated_mb:.2f}") + + if get_dict: + return stats + def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=None): + + memory_stats() + torch.cuda.empty_cache() # TODO: Do I need this? Just want to make sure I have mem cleaned up before training starts. + logger.info(f"called: torch.cuda.empty_cache()") + memory_stats() use_fp16 = cfg.dtype in [torch.float16, torch.bfloat16] scaler = GradScaler() if use_fp16 else None global_step, best_val_loss = 0, float("inf") - if cfg.use_unsloth and FastModel is not None: - FastModel.for_training(model) # Enable for inference! - else: - model.train() - model.to(cfg.device) + # total_trainable = 0 + # total_params = 0 + # for n, p in model.named_parameters(): + # total_params += p.numel() + # if p.requires_grad: + # total_trainable += p.numel() + # logger.info(f"Total trainable parameters before train(): {total_trainable:,}") + + # if cfg.use_unsloth and FastModel is not None: + # # logger.info("Before setting for training...") + # # model.print_trainable_parameters() + # FastModel.for_training(model, use_gradient_checkpointing=True) # Enable for training! 
# TODO :calling this method uses so much memory, investigate + # else: + # model.train() + # model.to(cfg.device) + + model.train() + model.to(cfg.device) + + # total_trainable = 0 + # total_params = 0 + # for n, p in model.named_parameters(): + # total_params += p.numel() + # if p.requires_grad: + # total_trainable += p.numel() + # logger.info(f"Total trainable parameters after train(): {total_trainable:,}") + + logger.info("after setting for training...") + # model.print_trainable_parameters() + for epoch in range(cfg.epochs): for idx, batch in enumerate(train_loader): + + torch.cuda.reset_peak_memory_stats(cfg.device) + torch.cuda.empty_cache() + loss = step(model, batch, cfg.device, use_fp16, optimizer, scaler) if global_step % 1 == 0: logger.info(f"Epoch:{epoch} Step:{global_step} Loss:{loss:.4f}") - wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) + # wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) if val_loader and global_step % cfg.validate_steps_freq == 0: - val_loss = validate_all(model, val_loader, cfg, use_fp16, val_batches=5) # TODO, disable val_batches in final commit/run + val_loss = validate_all(model, val_loader, cfg, use_fp16, val_batches=1) # if val_batches>0 the code will validate on that many batches only. -1 to disable this logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") - wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) + # wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) if val_loss < best_val_loss: best_val_loss = val_loss @@ -147,6 +243,11 @@ def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=No if global_step>cfg.max_step_to_train-1 and cfg.max_step_to_train>0: break global_step += 1 + if global_step % 5 == 0: + memory_stats() + + if global_step % 5 == 0: + memory_stats() return model @@ -180,11 +281,10 @@ def load_model(cfg:Configuration): lora_alpha=lcfg.alpha, # Recommended alpha == r at least lora_dropout=lcfg.dropout, bias = "none", - random_state = 3407, + random_state = 3407 + # TODO add rs_lora and dora ) - model.print_trainable_parameters() - else: quant_args = {} @@ -193,7 +293,7 @@ def load_model(cfg:Configuration): bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="fp4", + bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=cfg.dtype, ) quant_args = {"quantization_config": bnb_config, "device_map": "auto"} @@ -205,19 +305,28 @@ def load_model(cfg:Configuration): **quant_args, ) + if cfg.finetune_method in {"lora", "qlora"}: + if cfg.finetune_method=="qlora": + model = prepare_model_for_kbit_training(model) + + lora_cfg = LoraConfig( r=lcfg.r, lora_alpha=lcfg.alpha, target_modules=lcfg.target_modules, lora_dropout=lcfg.dropout, bias="none", + use_dora=True if cfg.finetune_method=="qlora" else False, + use_rslora=True # Rank-Stabilized LoRA --> `lora_alpha/math.sqrt(r)` ) model = get_peft_model(model, lora_cfg) - model.print_trainable_parameters() + memory_stats() torch.cuda.empty_cache() # TODO: Do I need this? Just want to make sure I have mem cleaned up before training starts. + logger.info(f"called: torch.cuda.empty_cache()") + memory_stats() elif cfg.finetune_method == "FFT": # handled below before printing params @@ -226,10 +335,13 @@ def load_model(cfg:Configuration): raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") for n, p in model.named_parameters(): - if cfg.finetune_method == "FFT": + if cfg.finetune_method == "FFT": # TODO: should FFT finetune all components? 
or just some, change FFT name to just FT? p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) if p.requires_grad: print(f"{n} will be finetuned") + + if cfg.finetune_method in {"lora", "qlora"}: + model.print_trainable_parameters() return model, tokenizer @@ -245,23 +357,23 @@ def load_model(cfg:Configuration): # 3. Get Data if cfg.use_unsloth: - train_dataloader = get_dataloader(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype,split="train",is_unsloth=True) - validation_dataloader = get_dataloader(tokenizer=tokenizer, args=cfg, dtype=cfg.dtype, split="validation",is_unsloth=True) + train_dataloader = get_dataloader(args=cfg,tokenizer=tokenizer, split="train",is_unsloth=True) + validation_dataloader = get_dataloader(args=cfg,tokenizer=tokenizer, split="validation",is_unsloth=True) else: processor = AutoProcessor.from_pretrained(cfg.model_id) - train_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype, split="train") - validation_dataloader = get_dataloader(processor=processor, args=cfg, dtype=cfg.dtype, split="validation") + train_dataloader = get_dataloader(args=cfg, processor=processor, split="train") + validation_dataloader = get_dataloader(args=cfg, processor=processor, split="validation") # Credits to Sayak Paul for this beautiful expression params_to_train = list(filter(lambda x: x.requires_grad, model.parameters())) optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate) - # 5. Enable logging, need to login or set wanddb token in os.env - wandb.init( - project=cfg.wandb_project_name, - name=cfg.run_name if hasattr(cfg, "run_name") else None, - config=vars(cfg), - ) + # # 5. Enable logging, need to login or set wanddb token in os.env + # wandb.init( + # project=cfg.wandb_project_name, + # name=cfg.run_name if hasattr(cfg, "run_name") else None, + # config=vars(cfg), + # ) # 5. Actual train and validation, validation_dataloader=None to do just traing. train_model(model, optimizer, cfg, train_dataloader, validation_dataloader) @@ -276,5 +388,5 @@ def load_model(cfg:Configuration): # 8. 
Wrap up - wandb.finish() + # wandb.finish() logger.info("Train finished") diff --git a/utils/config.py b/utils/config.py index 17efd93..72dd573 100644 --- a/utils/config.py +++ b/utils/config.py @@ -14,7 +14,7 @@ def str2bool(v): @dataclass class UserLoRAConfig: r: int = 32 - alpha: int = 64 + alpha: int = 32 dropout: float = 0.05 target_modules: List[str] = field(default_factory=lambda: [ "q_proj", "k_proj", "v_proj", "o_proj", @@ -29,17 +29,17 @@ class UserLoRAConfig: @dataclass class Configuration: dataset_id: str = "ariG23498/license-detection-paligemma" - model_id: str = "google/gemma-3-4b-pt" - checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-aug" + model_id: str = "unsloth/gemma-3-4b-it" #"google/gemma-3-4b-pt" + checkpoint_id: str = "ajaymin28/Gemma3_ObjeDet" push_model_to_hub: bool = False device: str = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.bfloat16 - validate_steps_freq: int = 500 + validate_steps_freq: int = 10 batch_size: int = 16 learning_rate: float = 2e-5 - epochs: int = 2 - max_step_to_train: int = 5000 # if model converges before training one epoch, set to 0 or -1 to disable - finetune_method: str = "FFT" # FFT | lora | qlora + epochs: int = 1 + max_step_to_train: int = 100 # if model converges before training one epoch, set to 0 or -1 to disable + finetune_method: str = "lora" # FFT | lora | qlora use_unsloth: bool = False mm_tunable_parts: List[str] = field(default_factory=lambda: ["multi_modal_projector"]) # vision_tower,language_model lora: UserLoRAConfig = field(default_factory=UserLoRAConfig) @@ -82,7 +82,7 @@ def from_args(cls): parser.add_argument("--lora.load_in_4bit", type=str2bool,default=cfg_dict["lora"]["load_in_4bit"]) parser.add_argument("--lora.load_in_8bit", type=str2bool,default=cfg_dict["lora"]["load_in_8bit"]) - parser.add_argument("--wandb_project_name", type=str, default=cfg_dict["project_name"]) + parser.add_argument("--wandb_project_name", type=str, default=cfg_dict["wandb_project_name"]) args = parser.parse_args() From de709105f32927277e746af7a8ff3804238af69b Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Tue, 17 Jun 2025 21:44:23 +0000 Subject: [PATCH 21/22] updated requirements --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index 84ddc32..c0f4b21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,6 @@ wandb peft albumentations argparse +omegaconf +unsloth==2025.5.7 +unsloth-zoo==2025.5.8 \ No newline at end of file From d5fd4e07ca6d6c61aeb976972c56fc98c6f57a73 Mon Sep 17 00:00:00 2001 From: Jaimin Bhoi Date: Tue, 17 Jun 2025 21:52:54 +0000 Subject: [PATCH 22/22] cleaned code, enable wandb, tested google gemma model --- configs/config.yaml | 2 +- train.py | 52 ++++++++------------------------------------- utils/config.py | 2 +- 3 files changed, 11 insertions(+), 45 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index 27d11a5..0bf11c3 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,5 +1,5 @@ dataset_id: "ariG23498/license-detection-paligemma" -model_id: "unsloth/gemma-3-4b-it" #"google/gemma-3-4b-pt" +model_id: "google/gemma-3-4b-pt" # "unsloth/gemma-3-4b-it" checkpoint_id: "ajaymin28/Gemma3_ObjeDet" device: "cuda" diff --git a/train.py b/train.py index 08a2b0e..378def5 100644 --- a/train.py +++ b/train.py @@ -36,25 +36,6 @@ A.ColorJitter(p=0.2), ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) -# TODO: Delete this after testing 
get_dataloader() with is_unsloth=True flag -# def get_dataloader_unsloth(tokenizer, args, dtype, split="train"): -# logger.info("Fetching the dataset") -# train_dataset = load_dataset(args.dataset_id, split=split) # or cfg.dataset_id -# train_collate_fn = partial( -# train_collate_function_unsloth, -# tokenizer=tokenizer, # <- Use the Unsloth tokenizer instead of processor -# dtype=dtype, -# transform=augmentations -# ) - -# logger.info("Building data loader") -# train_dataloader = DataLoader( -# train_dataset, -# batch_size=args.batch_size, -# collate_fn=train_collate_fn, -# shuffle=True, -# ) -# return train_dataloader def get_dataloader(args:Configuration,processor=None,tokenizer=None, split="train", is_unsloth=False): logger.info(f"Fetching the dataset: {cfg.dataset_id}:{split}") @@ -189,13 +170,6 @@ def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=No scaler = GradScaler() if use_fp16 else None global_step, best_val_loss = 0, float("inf") - # total_trainable = 0 - # total_params = 0 - # for n, p in model.named_parameters(): - # total_params += p.numel() - # if p.requires_grad: - # total_trainable += p.numel() - # logger.info(f"Total trainable parameters before train(): {total_trainable:,}") # if cfg.use_unsloth and FastModel is not None: # # logger.info("Before setting for training...") @@ -208,14 +182,6 @@ def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=No model.train() model.to(cfg.device) - # total_trainable = 0 - # total_params = 0 - # for n, p in model.named_parameters(): - # total_params += p.numel() - # if p.requires_grad: - # total_trainable += p.numel() - # logger.info(f"Total trainable parameters after train(): {total_trainable:,}") - logger.info("after setting for training...") # model.print_trainable_parameters() @@ -229,11 +195,11 @@ def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=No loss = step(model, batch, cfg.device, use_fp16, optimizer, scaler) if global_step % 1 == 0: logger.info(f"Epoch:{epoch} Step:{global_step} Loss:{loss:.4f}") - # wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) + wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) if val_loader and global_step % cfg.validate_steps_freq == 0: val_loss = validate_all(model, val_loader, cfg, use_fp16, val_batches=1) # if val_batches>0 the code will validate on that many batches only. -1 to disable this logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") - # wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) + wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) if val_loss < best_val_loss: best_val_loss = val_loss @@ -368,12 +334,12 @@ def load_model(cfg:Configuration): params_to_train = list(filter(lambda x: x.requires_grad, model.parameters())) optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate) - # # 5. Enable logging, need to login or set wanddb token in os.env - # wandb.init( - # project=cfg.wandb_project_name, - # name=cfg.run_name if hasattr(cfg, "run_name") else None, - # config=vars(cfg), - # ) + # 5. Enable logging, need to login or set wanddb token in os.env + wandb.init( + project=cfg.wandb_project_name, + name=cfg.run_name if hasattr(cfg, "run_name") else None, + config=vars(cfg), + ) # 5. Actual train and validation, validation_dataloader=None to do just traing. train_model(model, optimizer, cfg, train_dataloader, validation_dataloader) @@ -388,5 +354,5 @@ def load_model(cfg:Configuration): # 8. 
Wrap up - # wandb.finish() + wandb.finish() logger.info("Train finished") diff --git a/utils/config.py b/utils/config.py index 72dd573..e8befe8 100644 --- a/utils/config.py +++ b/utils/config.py @@ -29,7 +29,7 @@ class UserLoRAConfig: @dataclass class Configuration: dataset_id: str = "ariG23498/license-detection-paligemma" - model_id: str = "unsloth/gemma-3-4b-it" #"google/gemma-3-4b-pt" + model_id: str = "google/gemma-3-4b-pt" # "unsloth/gemma-3-4b-it" checkpoint_id: str = "ajaymin28/Gemma3_ObjeDet" push_model_to_hub: bool = False device: str = "cuda" if torch.cuda.is_available() else "cpu"
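
The final patch leaves step "# 7. Test?" as a TODO and pushes only the LoRA/QLoRA adapter to the Hub. As a possible follow-up, here is a minimal sketch (not part of the patches above) of merging the saved adapter back into the base weights so the resulting checkpoint can be loaded without PEFT. It assumes the directory layout produced by save_best_model() (f"checkpoints/{cfg.checkpoint_id}_best"), the base model id from configs/config.yaml, and a hypothetical output directory; adjust the names to the actual run.

    import torch
    from peft import PeftModel
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    base_id = "google/gemma-3-4b-pt"                           # cfg.model_id in configs/config.yaml
    adapter_dir = "checkpoints/ajaymin28/Gemma3_ObjeDet_best"  # f"checkpoints/{cfg.checkpoint_id}_best"
    merged_dir = "checkpoints/Gemma3_ObjeDet_merged"           # hypothetical output directory

    # Load the base model with the same class used during training, then attach the adapter.
    base = Gemma3ForConditionalGeneration.from_pretrained(base_id, torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(base, adapter_dir)

    # Fold the LoRA deltas into the base weights and save a standalone checkpoint.
    merged = model.merge_and_unload()
    merged.save_pretrained(merged_dir)
    AutoProcessor.from_pretrained(base_id).save_pretrained(merged_dir)

Pushing only the adapter keeps the upload small, while a merged checkpoint is a drop-in replacement for the base model; either option can sit on top of push_to_hub() in utils/utilities.py.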