diff --git a/config.py b/config.py deleted file mode 100644 index 333d864..0000000 --- a/config.py +++ /dev/null @@ -1,19 +0,0 @@ -from dataclasses import dataclass - -import torch - - -@dataclass -class Configuration: - dataset_id: str = "ariG23498/license-detection-paligemma" - - model_id: str = "google/gemma-3-4b-pt" - checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-aug" - - device: str = "cuda" if torch.cuda.is_available() else "cpu" - dtype: torch.dtype = torch.bfloat16 - - batch_size: int = 8 - learning_rate: float = 2e-05 - epochs = 2 - diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..0bf11c3 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,25 @@ +dataset_id: "ariG23498/license-detection-paligemma" +model_id: "google/gemma-3-4b-pt" # "unsloth/gemma-3-4b-it" +checkpoint_id: "ajaymin28/Gemma3_ObjeDet" + +device: "cuda" +dtype: "bfloat16" + +batch_size: 16 +learning_rate: 2e-5 +epochs: 1 +max_step_to_train: 100 +validate_steps_freq: 10 + +finetune_method: "qlora" # FFT | lora | qlora +use_unsloth: false + + +mm_tunable_parts: + - no_exist_layer # basically not finetuning any base components + # - mlp + # - multi_modal_projector + # - vision_tower + # - language_model +wandb_project_name: "Gemma3_LoRA" +push_model_to_hub: true \ No newline at end of file diff --git a/configs/lora_config.yaml b/configs/lora_config.yaml new file mode 100644 index 0000000..b4943c5 --- /dev/null +++ b/configs/lora_config.yaml @@ -0,0 +1,17 @@ +r: 32 +alpha: 32 +dropout: 0.05 +target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - up_proj + - down_proj + - gate_proj +max_seq_length: 2048 # Unsloth will RoPE-scale + +# LoRA-specific: no quantization +load_in_4bit: false +load_in_8bit: false +quantization_config: null \ No newline at end of file diff --git a/configs/qlora_config.yaml b/configs/qlora_config.yaml new file mode 100644 index 0000000..dc4d134 --- /dev/null +++ b/configs/qlora_config.yaml @@ -0,0 +1,20 @@ +r: 32 +alpha: 32 +dropout: 0.05 +target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - up_proj + - down_proj + - gate_proj +max_seq_length: 2048 # Unsloth will RoPE-scale + +# QLoRA-specific: quantization enabled +load_in_4bit: true +load_in_8bit: false +quantization_config: + bnb_4bit_use_double_quant: true + bnb_4bit_quant_type: "nf4" + bnb_4bit_compute_dtype: "bfloat16" diff --git a/predict.py b/predict.py index 4d49652..c8f1303 100644 --- a/predict.py +++ b/predict.py @@ -5,8 +5,8 @@ from torch.utils.data import DataLoader from transformers import AutoProcessor, Gemma3ForConditionalGeneration -from config import Configuration -from utils import test_collate_function, visualize_bounding_boxes +from utils.config import Configuration +from utils.utilities import test_collate_function, visualize_bounding_boxes os.makedirs("outputs", exist_ok=True) @@ -23,7 +23,7 @@ def get_dataloader(processor): if __name__ == "__main__": - cfg = Configuration() + cfg = Configuration.from_args() processor = AutoProcessor.from_pretrained(cfg.checkpoint_id) model = Gemma3ForConditionalGeneration.from_pretrained( cfg.checkpoint_id, diff --git a/requirements.txt b/requirements.txt index 84ddc32..c0f4b21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,6 @@ wandb peft albumentations argparse +omegaconf +unsloth==2025.5.7 +unsloth-zoo==2025.5.8 \ No newline at end of file diff --git a/train.py b/train.py index 8aab73d..378def5 100644 --- a/train.py +++ b/train.py @@ -1,22 +1,34 @@ + import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +FastModel = None +# Optional – comment below imports if you are not planinng to use unsloth +# try: from unsloth import FastModel +# except ImportError as e: logger.warning(f"Unsloth import error : {e}") +# except NotImplementedError as e: logger.warning(f"Unsloth NotImplementedError error : {e}") + + import wandb from functools import partial - import torch +from torch.amp import autocast, GradScaler from datasets import load_dataset from torch.utils.data import DataLoader from transformers import AutoProcessor, Gemma3ForConditionalGeneration +from transformers import BitsAndBytesConfig -from config import Configuration -from utils import train_collate_function -import argparse +from utils.config import Configuration +from utils.utilities import train_collate_function, train_collate_function_unsloth +from utils.utilities import save_best_model, push_to_hub, load_saved_model +from peft import get_peft_config, get_peft_model, LoraConfig +from peft import prepare_model_for_kbit_training import albumentations as A -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - augmentations = A.Compose([ A.Resize(height=896, width=896), @@ -25,12 +37,15 @@ ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True)) -def get_dataloader(processor, args, dtype): - logger.info("Fetching the dataset") - train_dataset = load_dataset(cfg.dataset_id, split="train") - train_collate_fn = partial( - train_collate_function, processor=processor, dtype=dtype, transform=augmentations - ) +def get_dataloader(args:Configuration,processor=None,tokenizer=None, split="train", is_unsloth=False): + logger.info(f"Fetching the dataset: {cfg.dataset_id}:{split}") + train_dataset = load_dataset(cfg.dataset_id, split=split) + + if is_unsloth: + # <- Use the Unsloth tokenizer instead of processor + train_collate_fn = partial(train_collate_function_unsloth,tokenizer=tokenizer,dtype=args.dtype,transform=augmentations) + else: + train_collate_fn = partial(train_collate_function, processor=processor, dtype=args.dtype, transform=augmentations) logger.info("Building data loader") train_dataloader = DataLoader( @@ -38,74 +53,306 @@ def get_dataloader(processor, args, dtype): batch_size=args.batch_size, collate_fn=train_collate_fn, shuffle=True, + pin_memory=True, ) return train_dataloader +def step(model, batch, device, use_fp16, optimizer=None, scaler=None): + """ + Single batch process + """ + data = batch.to(device) + if use_fp16: + with autocast(device_type=device): + loss = model(**data).loss + else: + loss = model(**data).loss + if optimizer: + optimizer.zero_grad() + if use_fp16: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + return loss.item() + +def validate_all(model, val_loader, cfg, use_fp16,val_batches=5): + + # if cfg.use_unsloth and FastModel is not None: + # FastModel.for_inference(model) # Enable for inference! + # else: + # model.eval() + model.eval() + + with torch.no_grad(): + if val_batches>0: + ## TODO: This logic is Temp and should be removed in final clean up + n_batches = val_batches + losses = [] + for i, batch in enumerate(val_loader): + if i >= n_batches: + break + losses.append(step(model, batch, cfg.device, use_fp16)) + else: + losses = [step(model, batch, cfg.device, use_fp16) for batch in val_loader] + model.train() + return sum(losses) / len(losses) if len(losses)> 0 else 0 + +import psutil +import os + +def memory_stats(get_dict=False, print_mem_usage=True, device=None): + stats = { + "cpu": "", + "ram": "", + "cuda_free": "", + "cuda_total": "", + "cuda_allocated": "", + "cuda_reserved": "", + "peak_vram_allocated_mb": "", + } + + cuda_freeMem = 0 + cuda_total = 0 + cuda_allocated = 0 + cuda_reserved = 0 + peak_vram_allocated_bytes = 0 + peak_vram_allocated_mb = 0 + + if torch.cuda.is_available(): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + try: + cuda_freeMem, cuda_total = torch.cuda.mem_get_info() + cuda_total = cuda_total/1024**2 + cuda_freeMem = cuda_freeMem/1024**2 + except: pass + + try: + cuda_allocated = torch.cuda.memory_allocated()/1024**2 + cuda_reserved = torch.cuda.memory_reserved()/1024**2 + except: pass + + try: + peak_vram_allocated_bytes = torch.cuda.max_memory_allocated(device) + peak_vram_allocated_mb = peak_vram_allocated_bytes / (1024 ** 2) + except: pass + + stats["cuda_free"] = cuda_freeMem + stats["cuda_total"] = cuda_total + stats["cuda_allocated"] = round(cuda_allocated,3) + stats["cuda_reserved"] = round(cuda_reserved,3) + stats["peak_vram_allocated_mb"] = round(peak_vram_allocated_mb,3) + + process = psutil.Process(os.getpid()) + ram_mem_perc = process.memory_percent() + cpu_usage = psutil.cpu_percent() + + stats["cpu"] = cpu_usage + stats["ram"] = ram_mem_perc + + if print_mem_usage: + logger.info(f"CPU: {cpu_usage:.2f}% RAM: {ram_mem_perc:.2f}% GPU memory Total: [{cuda_total:.2f}] Available: [{cuda_freeMem:.2f}] Allocated: [{cuda_allocated:.2f}] Reserved: [{cuda_reserved:.2f}] Cuda Peak Mem: {peak_vram_allocated_mb:.2f}") + + if get_dict: + return stats + +def train_model(model, optimizer, cfg:Configuration, train_loader, val_loader=None): + + memory_stats() + torch.cuda.empty_cache() # TODO: Do I need this? Just want to make sure I have mem cleaned up before training starts. + logger.info(f"called: torch.cuda.empty_cache()") + memory_stats() + use_fp16 = cfg.dtype in [torch.float16, torch.bfloat16] + scaler = GradScaler() if use_fp16 else None + global_step, best_val_loss = 0, float("inf") + + + # if cfg.use_unsloth and FastModel is not None: + # # logger.info("Before setting for training...") + # # model.print_trainable_parameters() + # FastModel.for_training(model, use_gradient_checkpointing=True) # Enable for training! # TODO :calling this method uses so much memory, investigate + # else: + # model.train() + # model.to(cfg.device) + + model.train() + model.to(cfg.device) + + logger.info("after setting for training...") + # model.print_trainable_parameters() + -def train_model(model, optimizer, cfg, train_dataloader): - logger.info("Start training") - global_step = 0 for epoch in range(cfg.epochs): - for idx, batch in enumerate(train_dataloader): - outputs = model(**batch.to(model.device)) - loss = outputs.loss - if idx % 100 == 0: - logger.info(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}") - wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step) + for idx, batch in enumerate(train_loader): - loss.backward() - optimizer.step() - optimizer.zero_grad() + torch.cuda.reset_peak_memory_stats(cfg.device) + torch.cuda.empty_cache() + + loss = step(model, batch, cfg.device, use_fp16, optimizer, scaler) + if global_step % 1 == 0: + logger.info(f"Epoch:{epoch} Step:{global_step} Loss:{loss:.4f}") + wandb.log({"train/loss": loss, "epoch": epoch}, step=global_step) + if val_loader and global_step % cfg.validate_steps_freq == 0: + val_loss = validate_all(model, val_loader, cfg, use_fp16, val_batches=1) # if val_batches>0 the code will validate on that many batches only. -1 to disable this + logger.info(f"Step:{global_step} Val Loss:{val_loss:.4f}") + wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step) + + if val_loss < best_val_loss: + best_val_loss = val_loss + save_best_model(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}, logger) + + ## Model seem to converge before even first epoch finishes for LoRA. set max_step_to_train<=0 to disable this. + if global_step>cfg.max_step_to_train-1 and cfg.max_step_to_train>0: + break global_step += 1 + if global_step % 5 == 0: + memory_stats() + + if global_step % 5 == 0: + memory_stats() + return model -if __name__ == "__main__": - cfg = Configuration.from_args() - - # Get values dynamicaly from user - parser = argparse.ArgumentParser(description="Training for PaLiGemma") - parser.add_argument('--model_id', type=str, required=True, default=cfg.model_id, help='Enter Huggingface Model ID') - parser.add_argument('--dataset_id', type=str, required=True ,default=cfg.dataset_id, help='Enter Huggingface Dataset ID') - parser.add_argument('--batch_size', type=int, default=cfg.batch_size, help='Enter Batch Size') - parser.add_argument('--lr', type=float, default=cfg.learning_rate, help='Enter Learning Rate') - parser.add_argument('--checkpoint_id', type=str, required=True, default=cfg.checkpoint_id, help='Enter Huggingface Repo ID to push model') - - args = parser.parse_args() - processor = AutoProcessor.from_pretrained(args.model_id) - train_dataloader = get_dataloader(processor=processor, args=args, dtype=cfg.dtype) - - logger.info("Getting model & turning only attention parameters to trainable") - model = Gemma3ForConditionalGeneration.from_pretrained( - cfg.model_id, - torch_dtype=cfg.dtype, - device_map="cpu", - attn_implementation="eager", - ) - for name, param in model.named_parameters(): - if "attn" in name: - param.requires_grad = True +def load_model(cfg:Configuration): + + lcfg = cfg.lora + tokenizer = None + + if cfg.use_unsloth and FastModel is not None: + + # TODO: For LoRA and QLoRa change unsloth config accordigly, generally load_in_4bit, load_in_8bit will be False or LoRA + model, tokenizer = FastModel.from_pretrained( + model_name = "unsloth/gemma-3-4b-it", + max_seq_length = 2048, # Choose any for long context! + load_in_4bit = True, # 4 bit quantization to reduce memory + load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory + full_finetuning = False, # [NEW!] We have full finetuning now! + # token = os.environ["HF_TOKEN"] # TODO: Handle this + ) + + if cfg.finetune_method in {"lora", "qlora"}: + model = FastModel.get_peft_model( + model, + finetune_vision_layers = True, # Turn off for just text! + finetune_language_layers = True, # Should leave on! + finetune_attention_modules = True, # Attention good for GRPO + finetune_mlp_modules = True, # SHould leave on always! + + r=lcfg.r, # Larger = higher accuracy, but might overfit + lora_alpha=lcfg.alpha, # Recommended alpha == r at least + lora_dropout=lcfg.dropout, + bias = "none", + random_state = 3407 + # TODO add rs_lora and dora + ) + + + else: + quant_args = {} + # Enable quantization only for QLoRA + if cfg.finetune_method in {"qlora"}: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=cfg.dtype, + ) + quant_args = {"quantization_config": bnb_config, "device_map": "auto"} + + model = Gemma3ForConditionalGeneration.from_pretrained( + cfg.model_id, + torch_dtype=cfg.dtype, + attn_implementation="eager", + **quant_args, + ) + + + if cfg.finetune_method in {"lora", "qlora"}: + + if cfg.finetune_method=="qlora": + model = prepare_model_for_kbit_training(model) + + + lora_cfg = LoraConfig( + r=lcfg.r, + lora_alpha=lcfg.alpha, + target_modules=lcfg.target_modules, + lora_dropout=lcfg.dropout, + bias="none", + use_dora=True if cfg.finetune_method=="qlora" else False, + use_rslora=True # Rank-Stabilized LoRA --> `lora_alpha/math.sqrt(r)` + ) + + model = get_peft_model(model, lora_cfg) + memory_stats() + torch.cuda.empty_cache() # TODO: Do I need this? Just want to make sure I have mem cleaned up before training starts. + logger.info(f"called: torch.cuda.empty_cache()") + memory_stats() + + elif cfg.finetune_method == "FFT": + # handled below before printing params + pass else: - param.requires_grad = False + raise ValueError(f"Unknown finetune_method: {cfg.finetune_method}") + + for n, p in model.named_parameters(): + if cfg.finetune_method == "FFT": # TODO: should FFT finetune all components? or just some, change FFT name to just FT? + p.requires_grad = any(part in n for part in cfg.mm_tunable_parts) + if p.requires_grad: + print(f"{n} will be finetuned") + + if cfg.finetune_method in {"lora", "qlora"}: + model.print_trainable_parameters() - model.train() - model.to(cfg.device) + return model, tokenizer + + +if __name__ == "__main__": + # 1. Parse CLI + YAMLs into config + cfg = Configuration.from_args() # config.yaml is overriden by CLI arguments + + # 2. Load model + logger.info(f"Getting model for {cfg.finetune_method}") + # loads model based on config. Unsloth, lora, qlora, FFT + model, tokenizer = load_model(cfg) + # 3. Get Data + if cfg.use_unsloth: + train_dataloader = get_dataloader(args=cfg,tokenizer=tokenizer, split="train",is_unsloth=True) + validation_dataloader = get_dataloader(args=cfg,tokenizer=tokenizer, split="validation",is_unsloth=True) + else: + processor = AutoProcessor.from_pretrained(cfg.model_id) + train_dataloader = get_dataloader(args=cfg, processor=processor, split="train") + validation_dataloader = get_dataloader(args=cfg, processor=processor, split="validation") + # Credits to Sayak Paul for this beautiful expression params_to_train = list(filter(lambda x: x.requires_grad, model.parameters())) - optimizer = torch.optim.AdamW(params_to_train, lr=args.lr) + optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate) + # 5. Enable logging, need to login or set wanddb token in os.env wandb.init( - project=cfg.project_name, + project=cfg.wandb_project_name, name=cfg.run_name if hasattr(cfg, "run_name") else None, config=vars(cfg), ) - train_model(model, optimizer, cfg, train_dataloader) + # 5. Actual train and validation, validation_dataloader=None to do just traing. + train_model(model, optimizer, cfg, train_dataloader, validation_dataloader) - # Push the checkpoint to hub - model.push_to_hub(cfg.checkpoint_id) - processor.push_to_hub(cfg.checkpoint_id) + # 6. Loading best model back + model, tokenizer = load_saved_model(cfg, is_lora=cfg.finetune_method in {"lora", "qlora"}, device="cuda", logger=logger) + logger.info(f"Pushing to hub at: {cfg.checkpoint_id}") + if cfg.push_model_to_hub: + push_to_hub(model, cfg, tokenizer, cfg.finetune_method in {"lora", "qlora"}) + # 7. Test? # TODO + + + # 8. Wrap up wandb.finish() logger.info("Train finished") diff --git a/utils.py b/utils.py deleted file mode 100644 index 2cee681..0000000 --- a/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import re - -import matplotlib.pyplot as plt -import numpy as np -from PIL import ImageDraw - -from create_dataset import format_objects - -def parse_paligemma_label(label, width, height): - # Extract location codes - loc_pattern = r"" - locations = [int(loc) for loc in re.findall(loc_pattern, label)] - - # Extract category (everything after the last location code) - category = label.split(">")[-1].strip() - - # Convert normalized locations back to original image coordinates - # Order in PaliGemma format is: y1, x1, y2, x2 - y1_norm, x1_norm, y2_norm, x2_norm = locations - - # Convert normalized coordinates to actual coordinates - x1 = (x1_norm / 1024) * width - y1 = (y1_norm / 1024) * height - x2 = (x2_norm / 1024) * width - y2 = (y2_norm / 1024) * height - - return category, [x1, y1, x2, y2] - - -def visualize_bounding_boxes(image, label, width, height, name): - # Create a copy of the image to draw on - draw_image = image.copy() - draw = ImageDraw.Draw(draw_image) - - # Parse the label - category, bbox = parse_paligemma_label(label, width, height) - - # Draw the bounding box - draw.rectangle(bbox, outline="red", width=2) - - # Add category label - draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red") - - # Show the image - plt.figure(figsize=(10, 6)) - plt.imshow(draw_image) - plt.axis("off") - plt.title(f"Bounding Box: {category}") - plt.tight_layout() - plt.savefig(name) - plt.show() - plt.close() - - -def train_collate_function(batch_of_samples, processor, dtype, transform=None): - images = [] - prompts = [] - for sample in batch_of_samples: - if transform: - transformed = transform(image=np.array(sample["image"]), bboxes=sample["objects"]["bbox"], category_ids=sample["objects"]["category"]) - sample["image"] = transformed["image"] - sample["objects"]["bbox"] = transformed["bboxes"] - sample["objects"]["category"] = transformed["category_ids"] - sample["height"] = sample["image"].shape[0] - sample["width"] = sample["image"].shape[1] - sample['label_for_paligemma'] = format_objects(sample)['label_for_paligemma'] - images.append([sample["image"]]) - prompts.append( - f"{processor.tokenizer.boi_token} detect \n\n{sample['label_for_paligemma']} {processor.tokenizer.eos_token}" - ) - - batch = processor(images=images, text=prompts, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() # Clone input IDs for labels - - # List from https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora - # Mask image tokens - image_token_id = [ - processor.tokenizer.convert_tokens_to_ids( - processor.tokenizer.special_tokens_map["boi_token"] - ) - ] - # Mask tokens for not being used in the loss computation - labels[labels == processor.tokenizer.pad_token_id] = -100 - labels[labels == image_token_id] = -100 - labels[labels == 262144] = -100 - - batch["labels"] = labels - - batch["pixel_values"] = batch["pixel_values"].to( - dtype - ) # to check with the implementation - return batch - - -def test_collate_function(batch_of_samples, processor, dtype): - images = [] - prompts = [] - for sample in batch_of_samples: - images.append([sample["image"]]) - prompts.append(f"{processor.tokenizer.boi_token} detect \n\n") - - batch = processor(images=images, text=prompts, return_tensors="pt", padding=True) - batch["pixel_values"] = batch["pixel_values"].to( - dtype - ) # to check with the implementation - return batch, images diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..e8befe8 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,123 @@ +import argparse +import torch +from dataclasses import dataclass, field +from typing import List +from omegaconf import OmegaConf + +def str2bool(v): + if isinstance(v, bool): return v + if v.lower() in ('yes', 'true', 't', '1'): return True + if v.lower() in ('no', 'false', 'f', '0'): return False + raise argparse.ArgumentTypeError("Boolean value expected.") + + +@dataclass +class UserLoRAConfig: + r: int = 32 + alpha: int = 32 + dropout: float = 0.05 + target_modules: List[str] = field(default_factory=lambda: [ + "q_proj", "k_proj", "v_proj", "o_proj", + "up_proj", "down_proj", "gate_proj" + ]) + max_seq_length: int = 2048 + #QLoRA + load_in_4bit: bool = True + load_in_8bit: bool = False # more precise bet takes more mem + + +@dataclass +class Configuration: + dataset_id: str = "ariG23498/license-detection-paligemma" + model_id: str = "google/gemma-3-4b-pt" # "unsloth/gemma-3-4b-it" + checkpoint_id: str = "ajaymin28/Gemma3_ObjeDet" + push_model_to_hub: bool = False + device: str = "cuda" if torch.cuda.is_available() else "cpu" + dtype: torch.dtype = torch.bfloat16 + validate_steps_freq: int = 10 + batch_size: int = 16 + learning_rate: float = 2e-5 + epochs: int = 1 + max_step_to_train: int = 100 # if model converges before training one epoch, set to 0 or -1 to disable + finetune_method: str = "lora" # FFT | lora | qlora + use_unsloth: bool = False + mm_tunable_parts: List[str] = field(default_factory=lambda: ["multi_modal_projector"]) # vision_tower,language_model + lora: UserLoRAConfig = field(default_factory=UserLoRAConfig) + wandb_project_name: str = "Gemma3_LoRA" + + @classmethod + def load(cls, main_cfg_path="configs/config.yaml", lora_cfg_path="configs/lora_config.yaml"): + base_cfg = OmegaConf.load(main_cfg_path) + lora_cfg = OmegaConf.load(lora_cfg_path) # TODO: Merge config into one, refer to hydra config. + base_cfg.lora = lora_cfg + return OmegaConf.to_container(base_cfg, resolve=True) + + @classmethod + def from_args(cls): + cfg_dict = cls.load() # Load YAML as dict + parser = argparse.ArgumentParser() + + # Top-level args + parser.add_argument("--dataset_id", type=str, default=cfg_dict["dataset_id"]) + parser.add_argument("--model_id", type=str, default=cfg_dict["model_id"]) + parser.add_argument("--checkpoint_id", type=str, default=cfg_dict["checkpoint_id"]) + parser.add_argument("--push_model_to_hub", type=str2bool, default=cfg_dict["push_model_to_hub"]) + parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default=cfg_dict["device"]) + parser.add_argument("--dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float16") + parser.add_argument("--batch_size", type=int, default=cfg_dict["batch_size"]) + parser.add_argument("--learning_rate", type=float, default=cfg_dict["learning_rate"]) + parser.add_argument("--epochs", type=int, default=cfg_dict["epochs"]) + parser.add_argument("--max_step_to_train", type=int, default=cfg_dict["max_step_to_train"]) + parser.add_argument("--validate_steps_freq", type=int, default=cfg_dict["validate_steps_freq"]) + parser.add_argument("--finetune_method", type=str, choices=["FFT", "lora", "qlora"], default=cfg_dict["finetune_method"]) + parser.add_argument("--use_unsloth", type=str2bool, default=cfg_dict["use_unsloth"]) + parser.add_argument("--mm_tunable_parts", type=str, default=",".join(cfg_dict["mm_tunable_parts"])) + + # LoRA nested config overrides + parser.add_argument("--lora.r", type=int, default=cfg_dict["lora"]["r"]) + parser.add_argument("--lora.alpha", type=int, default=cfg_dict["lora"]["alpha"]) + parser.add_argument("--lora.dropout", type=float, default=cfg_dict["lora"]["dropout"]) + parser.add_argument("--lora.target_modules", type=str, default=",".join(cfg_dict["lora"]["target_modules"])) + parser.add_argument("--lora.max_seq_length", type=int, default=cfg_dict["lora"]["max_seq_length"]) + parser.add_argument("--lora.load_in_4bit", type=str2bool,default=cfg_dict["lora"]["load_in_4bit"]) + parser.add_argument("--lora.load_in_8bit", type=str2bool,default=cfg_dict["lora"]["load_in_8bit"]) + + parser.add_argument("--wandb_project_name", type=str, default=cfg_dict["wandb_project_name"]) + + args = parser.parse_args() + + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + + lora_config = UserLoRAConfig( + r=args.__dict__["lora.r"], + alpha=args.__dict__["lora.alpha"], + dropout=args.__dict__["lora.dropout"], + target_modules=[x.strip() for x in args.__dict__["lora.target_modules"].split(',')], + max_seq_length=args.__dict__["lora.max_seq_length"], + load_in_4bit=args.__dict__["lora.load_in_4bit"], + load_in_8bit=args.__dict__["lora.load_in_8bit"], + ) + + # TODO handle this long list, probably migrate to hydra conf. + return cls( + dataset_id=args.dataset_id, + model_id=args.model_id, + checkpoint_id=args.checkpoint_id, + device=args.device, + dtype=dtype_map[args.dtype], + batch_size=args.batch_size, + learning_rate=args.learning_rate, + epochs=args.epochs, + finetune_method=args.finetune_method, + use_unsloth=args.use_unsloth, + mm_tunable_parts=[x.strip() for x in args.mm_tunable_parts.split(',')], + lora=lora_config, + wandb_project_name=args.wandb_project_name, + max_step_to_train=args.max_step_to_train, + push_model_to_hub=args.push_model_to_hub, + validate_steps_freq=args.validate_steps_freq + ) \ No newline at end of file diff --git a/create_dataset.py b/utils/create_dataset.py similarity index 96% rename from create_dataset.py rename to utils/create_dataset.py index 6dce684..825a5b3 100644 --- a/create_dataset.py +++ b/utils/create_dataset.py @@ -1,6 +1,5 @@ from datasets import load_dataset import argparse -from config import Configuration def coco_to_xyxy(coco_bbox): x, y, width, height = coco_bbox @@ -38,6 +37,8 @@ def format_objects(example): if __name__ == "__main__": + from utils.config import Configuration # To avoid circular import error + # Support for generic script for dataset cfg = Configuration() parser = argparse.ArgumentParser(description='Process dataset for PaLiGemma') diff --git a/utils/utilities.py b/utils/utilities.py new file mode 100644 index 0000000..214a251 --- /dev/null +++ b/utils/utilities.py @@ -0,0 +1,234 @@ +import os +import re +import argparse +import matplotlib.pyplot as plt +import numpy as np +from PIL import ImageDraw +import torch +from utils.create_dataset import format_objects +from transformers import AutoModel, AutoTokenizer # Change to your model class if needed +from peft import PeftModel, PeftConfig + +def parse_paligemma_label(label, width, height): + # Extract location codes + loc_pattern = r"" + locations = [int(loc) for loc in re.findall(loc_pattern, label)] + + # Extract category (everything after the last location code) + category = label.split(">")[-1].strip() + + # Convert normalized locations back to original image coordinates + # Order in PaliGemma format is: y1, x1, y2, x2 + y1_norm, x1_norm, y2_norm, x2_norm = locations + + # Convert normalized coordinates to actual coordinates + x1 = (x1_norm / 1024) * width + y1 = (y1_norm / 1024) * height + x2 = (x2_norm / 1024) * width + y2 = (y2_norm / 1024) * height + + return category, [x1, y1, x2, y2] + + +def visualize_bounding_boxes(image, label, width, height, name): + # Create a copy of the image to draw on + draw_image = image.copy() + draw = ImageDraw.Draw(draw_image) + + # Parse the label + category, bbox = parse_paligemma_label(label, width, height) + + # Draw the bounding box + draw.rectangle(bbox, outline="red", width=2) + + # Add category label + draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red") + + # Show the image + plt.figure(figsize=(10, 6)) + plt.imshow(draw_image) + plt.axis("off") + plt.title(f"Bounding Box: {category}") + plt.tight_layout() + plt.savefig(name) + plt.show() + plt.close() + +def train_collate_function_unsloth(batch_of_samples, tokenizer, dtype, transform=None): + """ + unsloth + """ + images = [] + prompts = [] + for sample in batch_of_samples: + if transform: + transformed = transform( + image=np.array(sample["image"]), + bboxes=sample["objects"]["bbox"], + category_ids=sample["objects"]["category"] + ) + sample["image"] = transformed["image"] + sample["objects"]["bbox"] = transformed["bboxes"] + sample["objects"]["category"] = transformed["category_ids"] + sample["height"] = sample["image"].shape[0] + sample["width"] = sample["image"].shape[1] + sample['label_for_paligemma'] = format_objects(sample)['label_for_paligemma'] + images.append([sample["image"]]) + prompts.append( + f"{tokenizer.boi_token} detect \n\n{sample['label_for_paligemma']} {tokenizer.eos_token}" + ) + + # Use tokenizer directly (Unsloth tokenizer supports vision inputs for Gemma3) + batch = tokenizer( + images=images, + text=prompts, + return_tensors="pt", + padding=True + ) + + labels = batch["input_ids"].clone() + + # Mask out padding, image tokens, and other special tokens from loss + image_token_id = [ + tokenizer.tokenizer.convert_tokens_to_ids(tokenizer.boi_token) + ] + labels[labels == tokenizer.pad_token_id] = -100 + for tok_id in image_token_id: + labels[labels == tok_id] = -100 + labels[labels == 262144] = -100 # If this ID is used for your "unused" special token + + batch["labels"] = labels + if "pixel_values" in batch: + batch["pixel_values"] = batch["pixel_values"].to(dtype) + + return batch + +def train_collate_function(batch_of_samples, processor, dtype, transform=None): + images = [] + prompts = [] + for sample in batch_of_samples: + if transform: + transformed = transform(image=np.array(sample["image"]), bboxes=sample["objects"]["bbox"], category_ids=sample["objects"]["category"]) + sample["image"] = transformed["image"] + sample["objects"]["bbox"] = transformed["bboxes"] + sample["objects"]["category"] = transformed["category_ids"] + sample["height"] = sample["image"].shape[0] + sample["width"] = sample["image"].shape[1] + sample['label_for_paligemma'] = format_objects(sample)['label_for_paligemma'] + images.append([sample["image"]]) + prompts.append( + f"{processor.tokenizer.boi_token} detect \n\n{sample['label_for_paligemma']} {processor.tokenizer.eos_token}" + ) + + batch = processor(images=images, text=prompts, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() # Clone input IDs for labels + + # List from https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora + # Mask image tokens + image_token_id = [ + processor.tokenizer.convert_tokens_to_ids( + processor.tokenizer.special_tokens_map["boi_token"] + ) + ] + # Mask tokens for not being used in the loss computation + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + labels[labels == 262144] = -100 + + batch["labels"] = labels + + batch["pixel_values"] = batch["pixel_values"].to( + dtype + ) # to check with the implementation + return batch + + +def test_collate_function(batch_of_samples, processor, dtype): + images = [] + prompts = [] + for sample in batch_of_samples: + images.append([sample["image"]]) + prompts.append(f"{processor.tokenizer.boi_token} detect \n\n") + + batch = processor(images=images, text=prompts, return_tensors="pt", padding=True) + batch["pixel_values"] = batch["pixel_values"].to( + dtype + ) # to check with the implementation + return batch, images + +def str2bool(v): + """ + Helper function to parse boolean values from cli arguments + """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def push_to_hub(model, cfg, tokenizer=None, is_lora=False): + """ + Push model to huggingface + """ + push_kwargs = {} + if tokenizer is not None: + push_kwargs['tokenizer'] = tokenizer + model.push_to_hub(cfg.checkpoint_id, **push_kwargs) + if tokenizer is not None: + tokenizer.push_to_hub(cfg.checkpoint_id) + + + +def save_best_model(model, cfg, tokenizer=None, is_lora=False, logger=None): + """Save LoRA adapter or full model based on config.""" + save_path = f"checkpoints/{cfg.checkpoint_id}_best" + os.makedirs(save_path, exist_ok=True) + if is_lora: + if logger: logger.info(f"Saving LoRA adapter to {save_path}") + model.save_pretrained(save_path) + if tokenizer is not None: + tokenizer.save_pretrained(save_path) + else: + if logger: logger.info(f"Saving full model weights to {save_path}.pt") + torch.save(model.state_dict(), f"{save_path}.pt") + + +def load_saved_model(cfg, is_lora=False, device=None, logger=None): + """ + Load LoRA adapter or full model based on config. + Returns (model, tokenizer) + """ + save_path = f"checkpoints/{cfg.checkpoint_id}_best" + tokenizer = None + + if cfg.use_unsloth: + + if logger: logger.info(f"Loading LoRA adapter from {save_path}") + + from unsloth import FastModel + model, tokenizer = FastModel.from_pretrained( + model_name = save_path, # YOUR MODEL YOU USED FOR TRAINING + load_in_4bit = True, # Set to False for 16bit LoRA + ) + return model, tokenizer + + else: + if is_lora: + if logger: logger.info(f"Loading LoRA adapter from {save_path}") + # Load base model first, then LoRA weights + base_model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") + model = PeftModel.from_pretrained(base_model, save_path, device_map=device or "auto") + if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): + tokenizer = AutoTokenizer.from_pretrained(save_path) + else: + if logger: logger.info(f"Loading full model weights from {save_path}.pt") + model = AutoModel.from_pretrained(cfg.model_id, device_map=device or "auto") + model.load_state_dict(torch.load(f"{save_path}.pt", map_location=device or "cpu")) + if os.path.exists(os.path.join(save_path, "tokenizer_config.json")): + tokenizer = AutoTokenizer.from_pretrained(save_path) + return model, tokenizer \ No newline at end of file