4 changes: 3 additions & 1 deletion .gitignore
@@ -1,3 +1,5 @@
.ruff_cache
.venv
__pycache__
.env
__pycache__
wandb/
1 change: 1 addition & 0 deletions COMMIT_MSG
@@ -0,0 +1 @@
docs: correct inference script name in README
30 changes: 18 additions & 12 deletions config.py
@@ -1,20 +1,26 @@
from dataclasses import dataclass

import torch


@dataclass
class Configuration:
dataset_id: str = "ariG23498/license-detection-paligemma"

project_name: str = "gemma-3-4b-pt-object-detection-aug" # "SmolVLM-256M-Instruct-object-detection-aug"
model_id: str = "google/gemma-3-4b-pt" # "HuggingFaceTB/SmolVLM-256M-Instruct"
checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-loc-tokens" # "sergiopaniego/SmolVLM-256M-Instruct-object-detection"

project_name: str = "gemma-3-4b-pt-object-detection-qlora-test"
model_id: str = "google/gemma-3-4b-pt"
checkpoint_id: str = "test/gemma-3-4b-pt-qlora-adapter"
device: str = "cuda" if torch.cuda.is_available() else "cpu"
dtype: torch.dtype = "auto" # Change to torch.bfloat16 for "google/gemma-3-4b-pt"
attn_implementation = "eager" # flash_attention_2

batch_size: int = 4 # 8 for "google/gemma-3-4b-pt"
dtype: torch.dtype = "auto"
attn_implementation = "eager"

# Reduced for testing
batch_size: int = 1
learning_rate: float = 2e-05
epochs = 10
epochs = 10 # Set to 1 for a quick test, or a higher value (e.g., 10) for a full training run
best_model_output_dir: str = "outputs/best_model"

# QLoRA parameters
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
lora_target_modules: list = None
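
Note (not part of the PR): the new QLoRA fields above are consumed by train.py further down in this diff. The snippet below is an illustrative sketch of how the Configuration defaults map onto a PEFT LoraConfig; the explicit target_modules list is an assumption standing in for the runtime discovery train.py performs when lora_target_modules is None.

# Illustrative sketch (not part of this PR): Configuration -> PEFT LoraConfig.
# The hard-coded target_modules list below is an assumption; train.py instead
# scans model.named_modules() for self-attention projections at runtime.
from peft import LoraConfig

from config import Configuration

cfg = Configuration()
peft_config = LoraConfig(
    r=cfg.lora_r,                   # 16
    lora_alpha=cfg.lora_alpha,      # 32
    lora_dropout=cfg.lora_dropout,  # 0.05
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed fallback set
    task_type="CAUSAL_LM",
)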
225 changes: 157 additions & 68 deletions train.py
@@ -1,16 +1,19 @@
import logging
import wandb
from functools import partial

import os
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM
from torch.utils.data import DataLoader, random_split
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

from config import Configuration
from utils import train_collate_function, get_processor_with_new_tokens, get_model_with_resize_token_embeddings
from utils import train_collate_function, get_model_with_resize_token_embeddings
import argparse
import albumentations as A
import sys
from PIL import Image

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -19,11 +22,8 @@


def get_augmentations(cfg):
if "SmolVLM" in cfg.model_id:
resize_size = 512
else:
resize_size = 896

# This can be customized
resize_size = 896
augmentations = A.Compose([
A.Resize(height=resize_size, width=resize_size),
A.HorizontalFlip(p=0.5),
@@ -32,81 +32,115 @@ def get_augmentations(cfg):
return augmentations


def get_dataloaders(processor, cfg):
logger.info(f"Fetching the dataset: {cfg.dataset_id}")

try:
logger.info("Attempting to load dataset with trust_remote_code=True")
dataset = load_dataset(cfg.dataset_id, trust_remote_code=True)

logger.info(f"Available splits: {list(dataset.keys())}")

if "validation" in dataset:
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
logger.info("Found train and validation splits")
else:
logger.info("No validation split found. Creating 90/10 split from train data.")
train_data = dataset["train"]
train_size = int(0.9 * len(train_data))
val_size = len(train_data) - train_size
train_dataset, val_dataset = random_split(train_data, [train_size, val_size])

except Exception as e:
logger.error(f"FATAL: Failed to load dataset: {e}")
logger.error("Please check the dataset ID in config.py and your internet connection.")
logger.error("If the dataset requires special permissions, ensure you are logged in with `huggingface-cli login`.")
sys.exit(1)

def get_dataloader(processor, cfg, split="train"):
logger.info(f"Fetching the {split} dataset")
train_dataset = load_dataset(cfg.dataset_id, split=split)
train_collate_fn = partial(
train_collate_function, processor=processor, device=cfg.device, transform=get_augmentations(cfg)
)
val_collate_fn = partial(
train_collate_function, processor=processor, device=cfg.device, transform=None
)

logger.info("Building data loader")
train_dataloader = DataLoader(
train_dataset,
batch_size=cfg.batch_size,
collate_fn=train_collate_fn,
shuffle=(split == "train"),
train_dataset, batch_size=cfg.batch_size, collate_fn=train_collate_fn, shuffle=True
)
return train_dataloader
val_dataloader = DataLoader(
val_dataset, batch_size=cfg.batch_size, collate_fn=val_collate_fn, shuffle=False
)
return train_dataloader, val_dataloader


# validation loss over the held-out split
@torch.no_grad()
def evaluate_model(model, val_dataloader, device):
model.eval()
total_loss = 0
total_count = 0
for batch in val_dataloader:
batch = batch.to(device)
batch = {k: v.to(device) if hasattr(v, 'to') else v for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
total_loss += loss.item() * batch["input_ids"].size(0)
total_count += batch["input_ids"].size(0)

avg_loss = total_loss / total_count if total_count > 0 else 0
model.train()
return avg_loss


def train_model(model, optimizer, cfg, train_dataloader, val_dataloader=None):
logger.info("Start training")
global_step = 0
best_val_loss = float("inf")
os.makedirs(cfg.best_model_output_dir, exist_ok=True)

for epoch in range(cfg.epochs):
logger.info(f"Starting epoch {epoch}")
for idx, batch in enumerate(train_dataloader):
outputs = model(**batch.to(model.device))
batch = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in batch.items()}

outputs = model(**batch)
loss = outputs.loss
if idx % 100 == 0:
logger.info(f"Epoch: {epoch} Iter: {idx}/{len(train_dataloader)} Loss: {loss.item():.4f}")
wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step)


if idx % 5 == 0:
logger.info(f"Epoch: {epoch}, Step: {idx}, Loss: {loss.item():.4f}")
loss.backward()
optimizer.step()
optimizer.zero_grad()
global_step += 1

# if idx >= 10:
# logger.info("Stopping after 10 steps for testing purposes")
# break

if val_dataloader is not None:
val_loss = evaluate_model(model, val_dataloader, cfg.device)
logger.info("Running validation...")
val_loss = evaluate_model(model, val_dataloader, model.device)
logger.info(f"Epoch: {epoch} Validation Loss: {val_loss:.4f}")
wandb.log({"val/loss": val_loss, "epoch": epoch}, step=global_step)

if val_loss < best_val_loss:
best_val_loss = val_loss
model.save_pretrained(cfg.best_model_output_dir)
logger.info(f"New best model adapter saved to {cfg.best_model_output_dir}")

return model

def set_trainable_params(model, keywords):
for name, param in model.named_parameters():
param.requires_grad = any(k in name for k in keywords)


def run_training_phase(model, processor, cfg, train_dataloader, train_keys, phase_name="phase", val_dataloader=None):
set_trainable_params(model, train_keys)
model.train()
model.to(cfg.device)

params_to_train = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate)

wandb.init(
project=cfg.project_name,
name=f"{cfg.run_name}_{phase_name}" if hasattr(cfg, "run_name") else phase_name,
config=vars(cfg),
def print_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
logger.info(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.4f}"
)

train_model(model, optimizer, cfg, train_dataloader, val_dataloader)
wandb.finish()

if __name__ == "__main__":
cfg = Configuration()
@@ -131,34 +165,89 @@ def run_training_phase(model, processor, cfg, train_dataloader, train_keys, phas
if args.checkpoint_id: cfg.checkpoint_id = args.checkpoint_id
if args.attn_imp: cfg.attn_implementation = args.attn_imp

logger.info("="*60)
logger.info("Starting QLoRA Training Test")
logger.info("="*60)

logger.info("Step 1: Loading processor...")
processor = AutoProcessor.from_pretrained(cfg.model_id)
if args.include_loc_tokens:
logger.info("Adding location tokens to the tokenizer")
processor = get_processor_with_new_tokens(processor)

train_dataloader = get_dataloader(processor=processor, cfg=cfg, split="train")
try:
val_dataloader = get_dataloader(processor=processor, cfg=cfg, split="validation")
except ValueError:
logger.warning("No validation split found in the dataset. Validation will be skipped.")
val_dataloader = None
# This function needs to be defined in utils.py
# processor = get_processor_with_new_tokens(processor)
pass
logger.info("Processor loaded.")

logger.info("Step 2: Loading and splitting dataset...")
train_dataloader, val_dataloader = get_dataloaders(processor=processor, cfg=cfg)
logger.info("Dataset loaded and dataloaders created.")

logger.info("Step 3: Loading model with QLoRA...")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)

logger.info("Loading model")
if "SmolVLM" in cfg.model_id:
model = AutoModelForVision2Seq.from_pretrained(cfg.model_id, device_map="auto")
else:
model = AutoModelForCausalLM.from_pretrained(cfg.model_id, torch_dtype=cfg.dtype, device_map="auto", attn_implementation=cfg.attn_implementation)
model = AutoModelForCausalLM.from_pretrained(
cfg.model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation=cfg.attn_implementation,
quantization_config=quantization_config,
trust_remote_code=True
)
logger.info("Base model loaded. Preparing for k-bit training...")

model = prepare_model_for_kbit_training(model)

if cfg.lora_target_modules is None:
import re
pattern = r'self_attn\.(q_proj|k_proj|v_proj|o_proj)$'
lora_target_modules = [name for name, _ in model.named_modules() if re.search(pattern, name)]
if not lora_target_modules:
logger.warning("No modules found with regex, using default module names")
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
cfg.lora_target_modules = list(set(lora_target_modules))

logger.info(f"Applying LoRA to modules: {cfg.lora_target_modules}")
peft_config = LoraConfig(
lora_alpha=cfg.lora_alpha,
lora_dropout=cfg.lora_dropout,
r=cfg.lora_r,
bias="none",
target_modules=cfg.lora_target_modules,
task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_config)

if args.include_loc_tokens:
logger.info("Resizing token embeddings for location tokens...")
model = get_model_with_resize_token_embeddings(model, processor)
logger.info("Model ready for training.")

print_trainable_parameters(model)

# To enable experiment tracking, uncomment the following lines and run `wandb login`
# logger.info("Step 4: Setting up optimizer and Weights & Biases...")
# wandb.init(
# project=cfg.project_name,
# name=f"{cfg.model_id.replace('/', '_')}-qlora",
# config=vars(cfg),
# )

logger.info("Step 4: Setting up optimizer...")
model.train()
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate)

logger.info("Single-stage: Fine-tuning embed_tokens + attn")
run_training_phase(model, processor, cfg, train_dataloader, train_keys=["embed_tokens", "attn"], phase_name="embed_attn_embed_tokens", val_dataloader=val_dataloader)
else:
logger.info("Single-stage: Fine-tuning attn only")
run_training_phase(model, processor, cfg, train_dataloader, train_keys=["attn"], phase_name="attn_only", val_dataloader=val_dataloader)

model.push_to_hub(cfg.checkpoint_id)
processor.push_to_hub(cfg.checkpoint_id)

logger.info("Train finished")
logger.info("Step 5: Starting training loop...")
train_model(model, optimizer, cfg, train_dataloader, val_dataloader)

# Uncomment the following line to finish logging to W&B
# wandb.finish()
logger.info("="*60)
logger.info("✅ QLoRA TRAINING TEST COMPLETED SUCCESSFULLY!")
logger.info("✅ Model trained with 4-bit quantization!")
logger.info("="*60)