Closed
Commits
24 commits
6d42662
initial changes [untested]
ajaymin28 Jun 16, 2025
b0b79fd
fixed device map to auto
ajaymin28 Jun 16, 2025
e239b6c
moved root files to modules
ajaymin28 Jun 16, 2025
3311eb6
updated lora training code
ajaymin28 Jun 16, 2025
8356959
updated quantization option for qlora
ajaymin28 Jun 16, 2025
c8eea4e
updated quantization option for qlora
ajaymin28 Jun 16, 2025
d8cebfb
added unsloth changes
ajaymin28 Jun 16, 2025
5c59f8e
fixed tokenizer issue, made img size to small
ajaymin28 Jun 16, 2025
d0bdf5e
unsloth uncommented
ajaymin28 Jun 16, 2025
1514fdb
set quant type to fp4
ajaymin28 Jun 16, 2025
5fd67e4
cfg print
ajaymin28 Jun 16, 2025
edb273e
wandb disable
ajaymin28 Jun 16, 2025
f9abfb0
working code for unsloth
ajaymin28 Jun 16, 2025
3ed763b
final checkpoint for unsloth training [uncleaned]
ajaymin28 Jun 16, 2025
56b25d0
final checkpoint for unsloth training [uncleaned]
ajaymin28 Jun 16, 2025
dcfd2d7
load best model back and push to hub
ajaymin28 Jun 16, 2025
c40dcbf
tested code train/val, save models
ajaymin28 Jun 16, 2025
9146526
Merge pull request #1 from ajaymin28/LoraTraining
ajaymin28 Jun 17, 2025
fd31e7c
cleaned code
ajaymin28 Jun 17, 2025
0edd0c6
added TODOs for cleanup and new features
ajaymin28 Jun 17, 2025
ef628bb
working code for unsloth qlora, vanilla qlora on l4
ajaymin28 Jun 17, 2025
de70910
updated requirements
ajaymin28 Jun 17, 2025
d5fd4e0
cleaned code, enable wandb, tested google gemma model
ajaymin28 Jun 17, 2025
95f62ce
Merge pull request #2 from ajaymin28/LoraTraining
ajaymin28 Jun 17, 2025
19 changes: 0 additions & 19 deletions config.py

This file was deleted.

25 changes: 25 additions & 0 deletions configs/config.yaml
@@ -0,0 +1,25 @@
dataset_id: "ariG23498/license-detection-paligemma"
model_id: "google/gemma-3-4b-pt" # "unsloth/gemma-3-4b-it"
checkpoint_id: "ajaymin28/Gemma3_ObjeDet"

device: "cuda"
dtype: "bfloat16"

batch_size: 16
learning_rate: 2e-5
epochs: 1
max_step_to_train: 100
validate_steps_freq: 10

finetune_method: "qlora" # FFT | lora | qlora
use_unsloth: false


mm_tunable_parts:
- no_exist_layer # basically not finetuning any base components
# - mlp
# - multi_modal_projector
# - vision_tower
# - language_model
wandb_project_name: "Gemma3_LoRA"
push_model_to_hub: true
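
The block above is the new top-level training config. Below is a minimal sketch of how it could be read with OmegaConf (which this PR adds to requirements.txt); the merge with a method-specific file is an illustrative assumption, not code taken from this PR.

# Illustrative sketch only: load configs/config.yaml and, depending on
# finetune_method, the matching LoRA/QLoRA file. Keys mirror the YAML above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/config.yaml")
print(cfg.model_id)         # "google/gemma-3-4b-pt"
print(cfg.finetune_method)  # "qlora"

# Hypothetical merge of the method-specific settings (not necessarily the PR's logic).
if cfg.finetune_method in ("lora", "qlora"):
    peft_cfg = OmegaConf.load(f"configs/{cfg.finetune_method}_config.yaml")
    cfg = OmegaConf.merge(cfg, OmegaConf.create({"peft": peft_cfg}))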
17 changes: 17 additions & 0 deletions configs/lora_config.yaml
@@ -0,0 +1,17 @@
r: 32
alpha: 32
dropout: 0.05
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
- gate_proj
max_seq_length: 2048 # Unsloth will RoPE-scale

# LoRA-specific: no quantization
load_in_4bit: false
load_in_8bit: false
quantization_config: null
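
For context, a hedged sketch of how these fields could be mapped onto a peft LoraConfig; the PR's training code is not shown in this excerpt, so the actual mapping may differ.

# Sketch, not the PR's code: build a peft LoraConfig from configs/lora_config.yaml.
from omegaconf import OmegaConf
from peft import LoraConfig

lora_yaml = OmegaConf.load("configs/lora_config.yaml")
lora_config = LoraConfig(
    r=lora_yaml.r,                                  # 32
    lora_alpha=lora_yaml.alpha,                     # 32
    lora_dropout=lora_yaml.dropout,                 # 0.05
    target_modules=list(lora_yaml.target_modules),  # q/k/v/o + MLP projections
    bias="none",                                    # assumed default, not in the YAML
    task_type="CAUSAL_LM",                          # assumed, not in the YAML
)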
20 changes: 20 additions & 0 deletions configs/qlora_config.yaml
@@ -0,0 +1,20 @@
r: 32
alpha: 32
dropout: 0.05
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
- gate_proj
max_seq_length: 2048 # Unsloth will RoPE-scale

# QLoRA-specific: quantization enabled
load_in_4bit: true
load_in_8bit: false
quantization_config:
bnb_4bit_use_double_quant: true
bnb_4bit_quant_type: "nf4"
bnb_4bit_compute_dtype: "bfloat16"
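
The quantization_config block maps naturally onto the bitsandbytes 4-bit options in transformers. A minimal sketch under that assumption (model_id reused from configs/config.yaml above):

# Sketch: 4-bit (QLoRA) model loading driven by configs/qlora_config.yaml.
import torch
from omegaconf import OmegaConf
from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration

qlora_yaml = OmegaConf.load("configs/qlora_config.yaml")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=qlora_yaml.load_in_4bit,  # true
    bnb_4bit_use_double_quant=qlora_yaml.quantization_config.bnb_4bit_use_double_quant,
    bnb_4bit_quant_type=qlora_yaml.quantization_config.bnb_4bit_quant_type,  # "nf4"
    bnb_4bit_compute_dtype=torch.bfloat16,  # from the "bfloat16" string above
)

model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-pt",          # model_id from configs/config.yaml
    quantization_config=bnb_config,
    device_map="auto",               # matches the "fixed device map to auto" commit
)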
6 changes: 3 additions & 3 deletions predict.py
@@ -5,8 +5,8 @@
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

-from config import Configuration
-from utils import test_collate_function, visualize_bounding_boxes
+from utils.config import Configuration
+from utils.utilities import test_collate_function, visualize_bounding_boxes

os.makedirs("outputs", exist_ok=True)

@@ -23,7 +23,7 @@ def get_dataloader(processor):


if __name__ == "__main__":
-cfg = Configuration()
+cfg = Configuration.from_args()
processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
cfg.checkpoint_id,
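
The predict.py change switches from a hard-coded Configuration() to Configuration.from_args(). utils/config.py is not part of this excerpt, so the following is only a hypothetical illustration of what that classmethod might look like (argparse plus OmegaConf, both listed in requirements.txt); the real implementation may differ.

# Hypothetical sketch of utils/config.py; not taken from this PR.
import argparse
from dataclasses import dataclass, fields
from omegaconf import OmegaConf

@dataclass
class Configuration:
    checkpoint_id: str = "ajaymin28/Gemma3_ObjeDet"
    device: str = "cuda"
    dtype: str = "bfloat16"

    @classmethod
    def from_args(cls):
        # Read --config (default: configs/config.yaml) and keep only known fields.
        parser = argparse.ArgumentParser()
        parser.add_argument("--config", default="configs/config.yaml")
        args, _ = parser.parse_known_args()
        yaml_cfg = OmegaConf.load(args.config)
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in yaml_cfg.items() if k in known})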
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -9,3 +9,6 @@ wandb
peft
albumentations
argparse
+omegaconf
+unsloth==2025.5.7
+unsloth-zoo==2025.5.8