This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 139ab1d

Add primary logic for full lora training
1 parent 6b77294 commit 139ab1d

1 file changed: train.py (102 additions, 39 deletions)
@@ -6,6 +6,7 @@
 import os
 import random
 import gc
+import copy
 
 from typing import Dict, Optional, Tuple
 from omegaconf import OmegaConf
@@ -44,11 +45,10 @@
     extract_lora_ups_down,
     inject_trainable_lora,
     inject_trainable_lora_extended,
-    safetensors_available,
     save_lora_weight,
-    save_safeloras,
-    collapse_lora,
-    monkeypatch_remove_lora
+    train_patch_pipe,
+    monkeypatch_or_replace_lora,
+    monkeypatch_or_replace_lora_extended
 )
 
 
@@ -180,7 +180,7 @@ def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_t
     except:
         print("Could not enable memory efficient attention for xformers or Torch 2.0.")
 
-def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int = 16):
+def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, lora_path='', r=128):
     injector = (
         inject_trainable_lora if not is_extended
         else
@@ -190,15 +190,44 @@ def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int =
     params = None
     negation = None
 
+    if os.path.exists(lora_path):
+        try:
+            for f in os.listdir(lora_path):
+                if f.endswith('.pt'):
+                    lora_file = os.path.join(lora_path, f)
+
+                    if 'text_encoder' in f and isinstance(model, CLIPTextModel):
+                        monkeypatch_or_replace_lora(
+                            model,
+                            torch.load(lora_file),
+                            target_replace_module=replace_modules,
+                            r=r
+                        )
+                        print("Successfully loaded Text Encoder LoRa.")
+
+                    if 'unet' in f and isinstance(model, UNet3DConditionModel):
+                        monkeypatch_or_replace_lora_extended(
+                            model,
+                            torch.load(lora_file),
+                            target_replace_module=replace_modules,
+                            r=r
+                        )
+                        print("Successfully loaded UNET LoRa.")
+
+        except Exception as e:
+            print(e)
+            print("Could not load LoRAs. Injecting new ones instead...")
+
     if use_lora:
         REPLACE_MODULES = replace_modules
-
-        params, negation = injector(
-            model,
-            target_replace_module=REPLACE_MODULES,
-            r=rank
-        )
-
+        injector_args = {
+            "model": model,
+            "target_replace_module": REPLACE_MODULES,
+            "r": 128
+        }
+        if not is_extended: injector_args['dropout_p'] = dropout
+
+        params, negation = injector(**injector_args)
         for _up, _down in extract_lora_ups_down(
             model,
             target_replace_module=REPLACE_MODULES):
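
Note on the injector_args dict above: only the standard injector exposes a dropout argument, so the call is assembled as a kwargs dict and dropout_p is added only when is_extended is False. Below is a minimal standalone sketch of that dispatch pattern, with stub functions standing in for inject_trainable_lora / inject_trainable_lora_extended; everything here is illustrative, not the lora library's API.

# Stub injectors: the "extended" one takes no dropout parameter, mirroring the diff above.
def stub_inject(model, target_replace_module, r, dropout_p=0.0):
    return ("standard", r, dropout_p)

def stub_inject_extended(model, target_replace_module, r):
    return ("extended", r)

def dispatch(is_extended, dropout=0.1, r=128):
    injector = stub_inject_extended if is_extended else stub_inject
    injector_args = {"model": None, "target_replace_module": {"SomeBlock"}, "r": r}
    if not is_extended:                      # dropout only valid for the standard injector
        injector_args["dropout_p"] = dropout
    return injector(**injector_args)

print(dispatch(is_extended=False))   # ('standard', 128, 0.1)
print(dispatch(is_extended=True))    # ('extended', 128)
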
@@ -210,15 +239,42 @@ def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int =
 
     return params, negation
 
-def handle_lora_save(use_unet_lora, use_text_lora, model, end_train=False):
-    if end_train:
-        if use_unet_lora:
-            collapse_lora(model.unet)
-            monkeypatch_remove_lora(model.unet)
-
-        if use_text_lora:
-            collapse_lora(model.text_encoder)
-            monkeypatch_remove_lora(model.text_encoder)
+def save_lora(model, name, condition, replace_modules, step, save_path):
+    if condition and replace_modules is not None:
+        save_path = f"{save_path}/{step}_{name}.pt"
+        save_lora_weight(model, save_path, replace_modules)
+
+def handle_lora_save(
+    use_unet_lora,
+    use_text_lora,
+    model,
+    save_path,
+    checkpoint_step,
+    unet_target_modules,
+    text_encoder_target_modules
+):
+
+    save_path = f"{save_path}/lora"
+    os.makedirs(save_path, exist_ok=True)
+
+    save_lora(
+        model.unet,
+        'unet',
+        use_unet_lora,
+        unet_target_modules,
+        checkpoint_step,
+        save_path,
+    )
+    save_lora(
+        model.text_encoder,
+        'text_encoder',
+        use_text_lora,
+        text_encoder_target_modules,
+        checkpoint_step,
+        save_path
+    )
+
+    train_patch_pipe(model, use_unet_lora, use_text_lora)
 
 def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
     return {
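
The new save layout ties back to the loader added in inject_lora: save_lora writes <save_path>/lora/<step>_<name>.pt with names 'unet' and 'text_encoder', which are exactly the substrings (plus the .pt suffix) that inject_lora scans lora_path for. A small sketch of the naming scheme follows; the directories and step value are hypothetical.

import os

def lora_checkpoint_path(save_path: str, step: int, name: str) -> str:
    # Mirrors handle_lora_save / save_lora: everything lands under <save_path>/lora,
    # named <step>_<name>.pt.
    lora_dir = os.path.join(save_path, "lora")
    os.makedirs(lora_dir, exist_ok=True)
    return os.path.join(lora_dir, f"{step}_{name}.pt")

print(lora_checkpoint_path("./outputs/run_a", 2500, "unet"))
# ./outputs/run_a/lora/2500_unet.pt
print(lora_checkpoint_path("./outputs/run_a", 2500, "text_encoder"))
# ./outputs/run_a/lora/2500_text_encoder.pt
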
@@ -415,7 +471,7 @@ def save_pipe(
     use_text_lora,
     unet_target_replace_module=None,
     text_target_replace_module=None,
-    is_checkpoint=False
+    is_checkpoint=False,
 ):
 
     if is_checkpoint:
@@ -427,35 +483,39 @@
     # Save the dtypes so we can continue training at the same precision.
     u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype
 
-    # We do this to prevent OOM during training when saving a checkpoint.
-    [x.to('cpu') for x in [unet, text_encoder, vae]]
+    # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
+    unet_out = copy.deepcopy(accelerator.unwrap_model(unet, keep_fp32_wrapper=False))
+    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False))
 
     pipeline = TextToVideoSDPipeline.from_pretrained(
         path,
-        unet=unet,
-        text_encoder=text_encoder,
+        unet=unet_out,
+        text_encoder=text_encoder_out,
         vae=vae,
-    )
+    ).to(torch_dtype=torch.float16)
 
     handle_lora_save(
-        use_unet_lora, use_text_lora,
+        use_unet_lora,
+        use_text_lora,
         pipeline,
-        output_dir,
+        output_dir,
         global_step,
-        unet_target_replace_module,
-        text_target_replace_module,
-        end_train=not is_checkpoint
+        unet_target_replace_module,
+        text_target_replace_module
     )
 
     pipeline.save_pretrained(save_path)
 
     if is_checkpoint:
+        unet, text_encoder = accelerator.prepare(unet, text_encoder)
         models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
         [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
 
     logger.info(f"Saved model at {save_path} on step {global_step}")
 
     del pipeline
+    del unet_out
+    del text_encoder_out
     torch.cuda.empty_cache()
     gc.collect()
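
The deepcopy of the unwrapped models is what lets checkpointing cast the exported pipeline to fp16 (and patch it for LoRA saving) without disturbing the fp32 weights still being trained. A minimal sketch of the idea with a toy module; only copy.deepcopy and standard torch calls are used here, not the repo's API.

import copy
import torch
import torch.nn as nn

model = nn.Linear(8, 8)                  # stands in for the training unet / text_encoder
model_out = copy.deepcopy(model)         # independent copy, no shared parameters
model_out = model_out.to(dtype=torch.float16)

print(model.weight.dtype)       # torch.float32 -- live training weights untouched
print(model_out.weight.dtype)   # torch.float16 -- export copy cast for saving
del model_out                    # drop the copy after saving, then empty_cache/gc as above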

@@ -544,6 +604,7 @@ def main(
     unet_lora_modules: Tuple[str] = ["ResnetBlock2D"],
     text_encoder_lora_modules: Tuple[str] = ["CLIPEncoderLayer"],
     lora_rank: int = 16,
+    lora_path: str = '',
     **kwargs
 ):
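
lora_path is the resume hook: pointing it at the lora folder written by a previous run makes inject_lora load those weights instead of injecting fresh layers. A hedged sketch of how the new key might sit in the OmegaConf config that feeds main(); only the key names come from main()'s signature, the values and the commented-out call are illustrative.

from omegaconf import OmegaConf

config = OmegaConf.create({
    "lora_rank": 16,
    "lora_path": "./outputs/previous_run/lora",   # '' (the default) injects new LoRA layers
    "unet_lora_modules": ["ResnetBlock2D"],
    "text_encoder_lora_modules": ["CLIPEncoderLayer"],
})
# main(**config)   # remaining training arguments omitted; lora_path reaches both inject_lora calls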

@@ -590,12 +651,12 @@
     # Use LoRA if enabled.
     unet_lora_params, unet_negation = inject_lora(
         use_unet_lora, unet, unet_lora_modules, is_extended=True,
-        rank=lora_rank
+        r=lora_rank, lora_path=lora_path
     )
 
     text_encoder_lora_params, text_encoder_negation = inject_lora(
         use_text_lora, text_encoder, text_encoder_lora_modules,
-        rank=lora_rank
+        r=lora_rank, lora_path=lora_path
     )
 
     # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
@@ -604,8 +665,8 @@
         param_optim(text_encoder, train_text_encoder and not use_text_lora, extra_params=extra_text_encoder_params,
             negation=text_encoder_negation
         ),
-        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 5e-5}),
-        param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 5e-6})
+        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 1e-5}),
+        param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 1e-5})
     ]
 
     params = create_optimizer_params(optim_params, learning_rate)
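
Both LoRA parameter sets now train at 1e-5 (previously 5e-5 for the text encoder and 5e-6 for the UNet). The param_optim entries ultimately become optimizer parameter groups with their own learning rates; here is a standalone sketch of that underlying torch pattern, independent of the repo's create_optimizer_params helper.

import torch
import torch.nn as nn

base = nn.Linear(16, 16)                                      # stands in for non-LoRA params
lora_down = nn.Linear(16, 4, bias=False)                      # toy LoRA pair
lora_up = nn.Linear(4, 16, bias=False)

optimizer = torch.optim.AdamW(
    [
        {"params": base.parameters()},                        # falls back to the default lr
        {"params": list(lora_down.parameters()) + list(lora_up.parameters()), "lr": 1e-5},
    ],
    lr=5e-6,
)

for group in optimizer.param_groups:
    print(group["lr"])    # 5e-06, then 1e-05
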
@@ -885,8 +946,8 @@ def finetune_unet(batch, train_encoder=False):
                         output_dir,
                         use_unet_lora,
                         use_text_lora,
-                        unet_target_replace_module=unet_lora_modules,
-                        text_target_replace_module=text_encoder_lora_modules,
+                        unet_lora_modules,
+                        text_encoder_lora_modules,
                         is_checkpoint=True
                     )
 
@@ -959,8 +1020,10 @@ def finetune_unet(batch, train_encoder=False):
         output_dir,
         use_unet_lora,
         use_text_lora,
+        unet_lora_modules,
+        text_encoder_lora_modules,
         is_checkpoint=False
-        )
+    )
     accelerator.end_training()
 
 if __name__ == "__main__":
