From 8016219829243df8067a26a615f7aa9f8abd02c6 Mon Sep 17 00:00:00 2001
From: Mohammed Amine Jebbar
Date: Tue, 27 May 2025 05:30:01 +0200
Subject: [PATCH 1/2] Add QLoRA feature

---
 config.py        |  50 +++++++++++++++++++---
 predict_qlora.py |  94 +++++++++++++++++++++++++++++++++++++++++
 train.py         | 108 ++++++++++++++++++++++++++++++++---------------
 3 files changed, 212 insertions(+), 40 deletions(-)
 create mode 100644 predict_qlora.py

diff --git a/config.py b/config.py
index 5461122..a683379 100644
--- a/config.py
+++ b/config.py
@@ -1,18 +1,56 @@
 from dataclasses import dataclass
-
 import torch
+from transformers import BitsAndBytesConfig
+from peft import LoraConfig
 
 
 @dataclass
 class Configuration:
+    # Identifiants
     dataset_id: str = "ariG23498/license-detection-paligemma"
     model_id: str = "google/gemma-3-4b-pt"
-    checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-aug"
+    checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-qlora"
 
+    # Infos projet (ajouté pour wandb)
+    project_name: str = "gemma3-detection"
+    run_name: str = "run-qlora"
+
+    # Entraînement
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     dtype: torch.dtype = torch.bfloat16
     batch_size: int = 8
-    learning_rate: float = 2e-05
-    epochs = 2
+    learning_rate: float = 2e-5
+    epochs: int = 2
+
+    # Activation QLoRA
+    use_qlora: bool = True
+
+    # Paramètres LoRA
+    lora_r: int = 16
+    lora_alpha: int = 32
+    lora_dropout: float = 0.1
+
+    @property
+    def bnb_config(self):
+        """Configuration de quantification 4-bit pour QLoRA"""
+        if not self.use_qlora:
+            return None
+        return BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=self.dtype,
+            bnb_4bit_use_double_quant=True,
+        )
+
+    @property
+    def lora_config(self):
+        """Configuration LoRA utilisée dans le setup du modèle"""
+        if not self.use_qlora:
+            return None
+        return LoraConfig(
+            r=self.lora_r,
+            lora_alpha=self.lora_alpha,
+            lora_dropout=self.lora_dropout,
+            bias="none",
+            task_type=None  # sera défini dans train.py selon le modèle
+        )
diff --git a/predict_qlora.py b/predict_qlora.py
new file mode 100644
index 0000000..098d249
--- /dev/null
+++ b/predict_qlora.py
@@ -0,0 +1,94 @@
+import os
+from functools import partial
+
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from peft import PeftModel
+
+from config import Configuration
+from utils import test_collate_function, visualize_bounding_boxes
+
+os.makedirs("outputs", exist_ok=True)
+
+
+def get_dataloader(processor):
+    test_dataset = load_dataset(cfg.dataset_id, split="test")
+    test_collate_fn = partial(
+        test_collate_function, processor=processor, dtype=cfg.dtype
+    )
+    test_dataloader = DataLoader(
+        test_dataset, batch_size=cfg.batch_size, collate_fn=test_collate_fn
+    )
+    return test_dataloader
+
+
+def load_model_for_inference(cfg):
+    """Charge le modèle pour l'inférence selon la configuration"""
+
+    if cfg.use_qlora:
+        # Charger le modèle de base avec quantification
+        print("Loading base model with quantization...")
+        base_model = Gemma3ForConditionalGeneration.from_pretrained(
+            cfg.model_id,
+            torch_dtype=cfg.dtype,
+            device_map="auto",
+            quantization_config=cfg.bnb_config,
+            trust_remote_code=True,
+        )
+
+        # Charger les adaptateurs LoRA
+        print("Loading LoRA adapters...")
+        model = PeftModel.from_pretrained(base_model, cfg.checkpoint_id)
+        print("Model loaded with QLoRA adapters")
+
+    else:
+        # Mode traditionnel : charger le modèle complet
+        print("Loading full fine-tuned model...")
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            cfg.checkpoint_id,
+            torch_dtype=cfg.dtype,
+            device_map="auto",
+        )
+
+    return model
+
+
+if __name__ == "__main__":
+    cfg = Configuration()
+
+    # Charger le processeur
+    processor = AutoProcessor.from_pretrained(
+        cfg.checkpoint_id if not cfg.use_qlora else cfg.model_id
+    )
+
+    # Charger le modèle selon la configuration
+    model = load_model_for_inference(cfg)
+    model.eval()
+
+    # Préparer les données de test
+    test_dataloader = get_dataloader(processor=processor)
+    sample, sample_images = next(iter(test_dataloader))
+
+    # Déplacer sur le bon device
+    sample = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in sample.items()}
+
+    # Génération
+    print("Generating predictions...")
+    generation = model.generate(**sample, max_new_tokens=100, do_sample=False)
+    decoded = processor.batch_decode(generation, skip_special_tokens=True)
+
+    # Visualisation des résultats
+    file_count = 0
+    for output_text, sample_image in zip(decoded, sample_images):
+        image = sample_image[0]
+        width, height = image.size
+
+        print(f"Generated text for image {file_count}: {output_text}")
+
+        visualize_bounding_boxes(
+            image, output_text, width, height, f"outputs/output_{file_count}.png"
+        )
+        file_count += 1
+
+    print(f"Generated {file_count} predictions in outputs/ directory")
\ No newline at end of file
diff --git a/train.py b/train.py
index f9c7e00..8ffbeef 100644
--- a/train.py
+++ b/train.py
@@ -6,18 +6,20 @@
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from peft import get_peft_model, prepare_model_for_kbit_training
 
 from config import Configuration
 from utils import train_collate_function
 
 import albumentations as A
 
+# Setup logger
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)
 
-
+# Define augmentations
 augmentations = A.Compose([
     A.Resize(height=896, width=896),
     A.HorizontalFlip(p=0.5),
@@ -25,14 +27,14 @@
 ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True))
 
 
-def get_dataloader(processor):
-    logger.info("Fetching the dataset")
+def get_dataloader(processor, cfg):
+    logger.info("Loading dataset")
     train_dataset = load_dataset(cfg.dataset_id, split="train")
     train_collate_fn = partial(
         train_collate_function, processor=processor, dtype=cfg.dtype, transform=augmentations
     )
 
-    logger.info("Building data loader")
+    logger.info("Building DataLoader")
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=cfg.batch_size,
@@ -42,60 +44,98 @@ def get_dataloader(processor):
     return train_dataloader
 
 
+def setup_model(cfg):
+    logger.info("Loading model with QLoRA configuration")
+
+    model = Gemma3ForConditionalGeneration.from_pretrained(
+        cfg.model_id,
+        torch_dtype=cfg.dtype,
+        device_map="auto",
+        attn_implementation="eager",
+        quantization_config=cfg.bnb_config if cfg.use_qlora else None,
+        trust_remote_code=True,
+    )
+
+    if cfg.use_qlora:
+        logger.info("Preparing model for QLoRA training")
+        model = prepare_model_for_kbit_training(model)
+
+        lora_config = cfg.lora_config
+        lora_config.target_modules = [
+            "q_proj", "k_proj", "v_proj", "o_proj",
+            "gate_proj", "up_proj", "down_proj",
+        ]
+        lora_config.task_type = "CAUSAL_LM"
+
+        model = get_peft_model(model, lora_config)
+        model.print_trainable_parameters()
+    else:
+        logger.info("Traditional mode - training attention layers only")
+        for name, param in model.named_parameters():
+            param.requires_grad = "attn" in name
+
+    return model
+
+
 def train_model(model, optimizer, cfg, train_dataloader):
-    logger.info("Start training")
+    logger.info("Starting training")
     global_step = 0
+
     for epoch in range(cfg.epochs):
         for idx, batch in enumerate(train_dataloader):
-            outputs = model(**batch.to(model.device))
-            loss = outputs.loss
+            # Move data to device
+            batch = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in batch.items()}
+
+            outputs = model(**batch)
+            loss = outputs.loss if hasattr(outputs, "loss") else outputs[0]
+
             if idx % 100 == 0:
-                logger.info(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}")
+                logger.info(f"Epoch {epoch} | Step {idx} | Loss: {loss.item():.4f}")
                 wandb.log({"train/loss": loss.item(), "epoch": epoch}, step=global_step)
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
             global_step += 1
+
     return model
 
 
 if __name__ == "__main__":
     cfg = Configuration()
 
-    processor = AutoProcessor.from_pretrained(cfg.model_id)
-    train_dataloader = get_dataloader(processor)
-    logger.info("Getting model & turning only attention parameters to trainable")
-    model = Gemma3ForConditionalGeneration.from_pretrained(
-        cfg.model_id,
-        torch_dtype=cfg.dtype,
-        device_map="cpu",
-        attn_implementation="eager",
+    # Initialize Weights & Biases
+    wandb.init(
+        project=cfg.project_name if hasattr(cfg, "project_name") else "gemma3-detection",
+        name=cfg.run_name if hasattr(cfg, "run_name") else "run-qlora" if cfg.use_qlora else "run-traditional",
+        config=vars(cfg),
     )
-    for name, param in model.named_parameters():
-        if "attn" in name:
-            param.requires_grad = True
-        else:
-            param.requires_grad = False
+
+    # Preprocessing
+    processor = AutoProcessor.from_pretrained(cfg.model_id)
+    train_dataloader = get_dataloader(processor, cfg)
+
+    # Load model
+    model = setup_model(cfg)
     model.train()
-    model.to(cfg.device)
 
-    # Credits to Sayak Paul for this beautiful expression
-    params_to_train = list(filter(lambda x: x.requires_grad, model.parameters()))
-    optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate)
+    # Optimizer
+    trainable_params = [p for p in model.parameters() if p.requires_grad]
+    logger.info(f"Number of trainable parameters: {sum(p.numel() for p in trainable_params):,}")
+    optimizer = torch.optim.AdamW(trainable_params, lr=cfg.learning_rate)
 
-    wandb.init(
-        project=cfg.project_name,
-        name=cfg.run_name if hasattr(cfg, "run_name") else None,
-        config=vars(cfg),
-    )
+    # Training
+    trained_model = train_model(model, optimizer, cfg, train_dataloader)
 
-    train_model(model, optimizer, cfg, train_dataloader)
+    # Save
+    logger.info("Saving model")
+    if cfg.use_qlora:
+        trained_model.save_pretrained(cfg.checkpoint_id)
+        trained_model.push_to_hub(cfg.checkpoint_id)
+    else:
+        model.push_to_hub(cfg.checkpoint_id)
 
-    # Push the checkpoint to hub
-    model.push_to_hub(cfg.checkpoint_id)
     processor.push_to_hub(cfg.checkpoint_id)
     wandb.finish()
-    logger.info("Train finished")
+    logger.info("Training complete")

From f624d29f9b213438dc3f9462081de382f761d631 Mon Sep 17 00:00:00 2001
From: Mohammed Amine Jebbar
Date: Tue, 27 May 2025 05:35:53 +0200
Subject: [PATCH 2/2] Edit comments

---
 config.py        | 16 ++++++++--------
 predict_qlora.py | 42 +++++++++++++++++++---------------------
 2 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/config.py b/config.py
index a683379..f7ace32 100644
--- a/config.py
+++ b/config.py
@@ -6,33 +6,33 @@
 
 @dataclass
 class Configuration:
-    # Identifiants
+    # Identifiers
     dataset_id: str = "ariG23498/license-detection-paligemma"
     model_id: str = "google/gemma-3-4b-pt"
     checkpoint_id: str = "sergiopaniego/gemma-3-4b-pt-object-detection-qlora"
 
-    # Infos projet (ajouté pour wandb)
+    # Project info (added for wandb)
     project_name: str = "gemma3-detection"
     run_name: str = "run-qlora"
 
-    # Entraînement
+    # Training
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     dtype: torch.dtype = torch.bfloat16
     batch_size: int = 8
     learning_rate: float = 2e-5
     epochs: int = 2
 
-    # Activation QLoRA
+    # QLoRA activation
     use_qlora: bool = True
 
-    # Paramètres LoRA
+    # LoRA parameters
     lora_r: int = 16
     lora_alpha: int = 32
     lora_dropout: float = 0.1
 
     @property
     def bnb_config(self):
-        """Configuration de quantification 4-bit pour QLoRA"""
+        """4-bit quantization configuration for QLoRA"""
         if not self.use_qlora:
             return None
         return BitsAndBytesConfig(
@@ -44,7 +44,7 @@ def bnb_config(self):
 
     @property
     def lora_config(self):
-        """Configuration LoRA utilisée dans le setup du modèle"""
+        """LoRA configuration used during model setup"""
         if not self.use_qlora:
             return None
         return LoraConfig(
@@ -52,5 +52,5 @@ def lora_config(self):
             lora_alpha=self.lora_alpha,
             lora_dropout=self.lora_dropout,
             bias="none",
-            task_type=None  # sera défini dans train.py selon le modèle
+            task_type=None  # will be set in train.py based on the model
         )
diff --git a/predict_qlora.py b/predict_qlora.py
index 098d249..5aff6f1 100644
--- a/predict_qlora.py
+++ b/predict_qlora.py
@@ -24,10 +24,10 @@ def get_dataloader(processor):
 
 
 def load_model_for_inference(cfg):
-    """Charge le modèle pour l'inférence selon la configuration"""
-    
+    """Loads the model for inference based on the configuration"""
+
     if cfg.use_qlora:
-        # Charger le modèle de base avec quantification
+        # Load the base model with quantization
         print("Loading base model with quantization...")
         base_model = Gemma3ForConditionalGeneration.from_pretrained(
             cfg.model_id,
@@ -36,59 +36,59 @@ def load_model_for_inference(cfg):
             quantization_config=cfg.bnb_config,
             trust_remote_code=True,
         )
-        
-        # Charger les adaptateurs LoRA
+
+        # Load LoRA adapters
         print("Loading LoRA adapters...")
         model = PeftModel.from_pretrained(base_model, cfg.checkpoint_id)
         print("Model loaded with QLoRA adapters")
-    
+
     else:
-        # Mode traditionnel : charger le modèle complet
+        # Traditional mode: load the fully fine-tuned model
         print("Loading full fine-tuned model...")
         model = Gemma3ForConditionalGeneration.from_pretrained(
             cfg.checkpoint_id,
             torch_dtype=cfg.dtype,
             device_map="auto",
         )
-    
+
     return model
 
 
 if __name__ == "__main__":
     cfg = Configuration()
-    
-    # Charger le processeur
+
+    # Load the processor
     processor = AutoProcessor.from_pretrained(
         cfg.checkpoint_id if not cfg.use_qlora else cfg.model_id
     )
-    
-    # Charger le modèle selon la configuration
+
+    # Load the model based on the configuration
     model = load_model_for_inference(cfg)
     model.eval()
 
-    # Préparer les données de test
+    # Prepare test data
     test_dataloader = get_dataloader(processor=processor)
    sample, sample_images = next(iter(test_dataloader))
-    
-    # Déplacer sur le bon device
+
+    # Move data to the correct device
     sample = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in sample.items()}
 
-    # Génération
+    # Generation
     print("Generating predictions...")
     generation = model.generate(**sample, max_new_tokens=100, do_sample=False)
     decoded = processor.batch_decode(generation, skip_special_tokens=True)
 
-    # Visualisation des résultats
+    # Visualize results
     file_count = 0
     for output_text, sample_image in zip(decoded, sample_images):
         image = sample_image[0]
         width, height = image.size
-        
+
         print(f"Generated text for image {file_count}: {output_text}")
-        
+
         visualize_bounding_boxes(
             image, output_text, width, height, f"outputs/output_{file_count}.png"
         )
         file_count += 1
-    
-    print(f"Generated {file_count} predictions in outputs/ directory")
\ No newline at end of file
+
+    print(f"Generated {file_count} predictions in outputs/ directory")
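
Usage note (outside the patch): after train.py pushes the LoRA adapters to cfg.checkpoint_id, the adapters can optionally be merged back into the full-precision base model to produce a standalone checkpoint for deployment. The sketch below is a minimal example built only on the public peft API (PeftModel.from_pretrained and merge_and_unload); the output directory "gemma-3-4b-pt-object-detection-merged" is a hypothetical local path, not something defined in this patch.

    import torch
    from transformers import Gemma3ForConditionalGeneration
    from peft import PeftModel

    from config import Configuration

    cfg = Configuration()

    # Load the base model in full precision (merging is done on unquantized weights).
    base_model = Gemma3ForConditionalGeneration.from_pretrained(
        cfg.model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )

    # Attach the trained LoRA adapters from the Hub checkpoint and fold them
    # into the base weights, returning a plain transformers model.
    model = PeftModel.from_pretrained(base_model, cfg.checkpoint_id)
    merged_model = model.merge_and_unload()

    # Save a standalone checkpoint; hypothetical local path.
    merged_model.save_pretrained("gemma-3-4b-pt-object-detection-merged")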