From 12b9c6ae7589f0d209560d717a75d3126ac0799e Mon Sep 17 00:00:00 2001
From: sergiopaniego
Date: Mon, 14 Jul 2025 16:50:27 +0200
Subject: [PATCH] Clone the layer to break shared weights and save it using
 safetensors

---
 train.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/train.py b/train.py
index 07a661c..9b135a8 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,7 @@
 from functools import partial
 
 import torch
+import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM
@@ -157,6 +158,13 @@ def run_training_phase(model, processor, cfg, train_dataloader, train_keys, phas
 else:
     logger.info("Single-stage: Fine-tuning attn only")
     run_training_phase(model, processor, cfg, train_dataloader, train_keys=["attn"], phase_name="attn_only", val_dataloader=val_dataloader)
+
+try:
+    if model.lm_head.weight.data_ptr() == model.language_model.embed_tokens.weight.data_ptr():
+        logger.info("Cloning lm_head to break shared weights...")
+        model.lm_head.weight = nn.Parameter(model.lm_head.weight.clone())
+except Exception as ex:
+    logger.info(f"Either the requested layers were not found or the cloning operation failed: {ex}")
 
 model.push_to_hub(cfg.checkpoint_id)
 processor.push_to_hub(cfg.checkpoint_id)
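
Note on why the clone is needed: safetensors cannot serialize tensors that share the same underlying storage, and many language and vision-language models tie lm_head.weight to the input embedding matrix. Comparing data_ptr() values is how the patch detects that sharing, and cloning lm_head.weight gives the head its own storage before push_to_hub (which serializes to safetensors by default in recent transformers releases). The standalone sketch below shows the same check-and-untie flow outside train.py; the checkpoint id and output directory are placeholders, and the model.lm_head / model.language_model.embed_tokens attribute paths are taken from the patch, so other architectures may expose the embeddings under different names.

import torch.nn as nn
from transformers import AutoModelForVision2Seq

# Placeholder checkpoint id; substitute the model actually used in train.py.
model = AutoModelForVision2Seq.from_pretrained("your-org/your-vlm-checkpoint")

lm_head = model.lm_head
embed_tokens = model.language_model.embed_tokens

# If the output head and the input embeddings point at the same storage,
# untie them by giving lm_head its own copy of the weight matrix.
if lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr():
    lm_head.weight = nn.Parameter(lm_head.weight.clone())

# save_pretrained (like push_to_hub) defaults to safetensors serialization
# in recent transformers versions; "untied-checkpoint" is a placeholder path.
model.save_pretrained("untied-checkpoint", safe_serialization=True)

The trade-off is that the saved checkpoint carries a second copy of the embedding-sized matrix; passing safe_serialization=False instead would keep the weights tied but fall back to pickle-based .bin files.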