This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit da3c5cc

Add lora saving functionality

1 parent 8ed8714 commit da3c5cc

1 file changed: train.py (59 additions, 8 deletions)
@@ -40,7 +40,7 @@
     ImageDataset, VideoFolderDataset, CachedDataset
 from einops import rearrange, repeat
 
-from lora_diffusion import (
+from utils.lora import (
     extract_lora_ups_down,
     inject_trainable_lora,
     inject_trainable_lora_extended,
@@ -271,7 +271,7 @@ def create_optimizer_params(model_list, lr):
         # If this is true, we can train it.
         if condition:
             for n, p in model.named_parameters():
-                should_negate = negate_params(n, negation)
+                should_negate = 'lora' in n
                 if should_negate: continue
 
                 params = create_optim_params(n, p, lr, extra_params)
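The optimizer helper now drops any parameter whose name contains "lora" instead of calling negate_params. A minimal sketch of why that string check is enough, assuming LoRA injection adds lora_up / lora_down sub-layers in the lora_diffusion style (the toy module below is illustrative, not from this repo):

```python
import torch.nn as nn

# Toy stand-in for a LoRA-injected layer: the injected low-rank weights show
# up in named_parameters() with "lora" in their names (assumption: utils.lora
# follows the lora_diffusion naming of lora_up / lora_down).
class ToyLoraInjectedLinear(nn.Module):
    def __init__(self, in_f, out_f, rank=4):
        super().__init__()
        self.linear = nn.Linear(in_f, out_f)
        self.lora_down = nn.Linear(in_f, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_f, bias=False)

model = nn.Sequential(ToyLoraInjectedLinear(8, 8))

# Same check as the diff: skip LoRA params when collecting base-model params,
# so the LoRA weights can be handled by their own optimizer group instead.
base_params = [n for n, _ in model.named_parameters() if 'lora' not in n]
print(base_params)  # ['0.linear.weight', '0.linear.bias']
```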
@@ -403,11 +403,6 @@ def should_sample(global_step, validation_steps, validation_data):
     return (global_step % validation_steps == 0 or global_step == 1) \
     and validation_data.sample_preview
 
-def replace_prompt(prompt, token, wlist):
-    for w in wlist:
-        if w in prompt: return prompt.replace(w, token)
-    return prompt
-
 def save_pipe(
     path,
     global_step,
@@ -418,6 +413,8 @@ def save_pipe(
     output_dir,
     use_unet_lora,
     use_text_lora,
+    unet_target_replace_module=None,
+    text_target_replace_module=None,
     is_checkpoint=False
 ):
 
@@ -440,7 +437,16 @@ def save_pipe(
         vae=vae,
     )
 
-    handle_lora_save(use_unet_lora, use_text_lora, pipeline, end_train=not is_checkpoint)
+    handle_lora_save(
+        use_unet_lora, use_text_lora,
+        pipeline,
+        output_dir,
+        global_step,
+        unet_target_replace_module,
+        text_target_replace_module,
+        end_train=not is_checkpoint
+    )
+
     pipeline.save_pretrained(save_path)
 
     if is_checkpoint:
@@ -451,6 +457,49 @@ def save_pipe(
 
     del pipeline
     torch.cuda.empty_cache()
+    gc.collect()
+
+
+def replace_prompt(prompt, token, wlist):
+    for w in wlist:
+        if w in prompt: return prompt.replace(w, token)
+    return prompt
+
+def handle_lora_save(
+    use_unet_lora,
+    use_text_lora,
+    model,
+    save_path,
+    checkpoint_step,
+    unet_target_replace_module=None,
+    text_target_replace_module=None,
+    end_train=False
+):
+    if end_train:
+        if use_unet_lora:
+            collapse_lora(model.unet)
+            monkeypatch_remove_lora(model.unet)
+
+        if use_text_lora:
+            collapse_lora(model.text_encoder)
+            monkeypatch_remove_lora(model.text_encoder)
+
+    if not end_train:
+        save_path = f"{save_path}/lora"
+        os.makedirs(save_path, exist_ok=True)
+
+        if use_unet_lora and unet_target_replace_module is not None:
+            save_lora_weight(
+                model.unet,
+                f"{save_path}/{checkpoint_step}_unet.pt",
+                unet_target_replace_module
+            )
+        if use_text_lora and text_target_replace_module is not None:
+            save_lora_weight(
+                model.text_encoder,
+                f"{save_path}/{checkpoint_step}_text_encoder.pt",
+                text_target_replace_module
+            )
 
 def main(
     pretrained_model_path: str,
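Taken together, the new helper has two modes: at checkpoints it writes standalone LoRA weight files, and at the end of training it folds the LoRA deltas into the base weights before the full pipeline is saved. A minimal sketch of the two calls as save_pipe makes them; the names below stand in for the values already in save_pipe's scope, so this is illustrative wiring rather than a standalone script:

```python
# Illustrative only: how save_pipe drives the new helper in its two modes.
# `pipeline` stands for the diffusers pipeline built inside save_pipe; the
# other names are the arguments save_pipe already receives.

# 1) Mid-training checkpoint (is_checkpoint=True -> end_train=False):
#    LoRA layers stay injected; weights go to
#    <output_dir>/lora/<global_step>_unet.pt and
#    <output_dir>/lora/<global_step>_text_encoder.pt.
handle_lora_save(
    use_unet_lora, use_text_lora,
    pipeline,
    output_dir,
    global_step,
    unet_target_replace_module,
    text_target_replace_module,
    end_train=False
)

# 2) Final save (is_checkpoint=False -> end_train=True):
#    collapse_lora folds the low-rank updates into the base weights and
#    monkeypatch_remove_lora strips the wrappers, so the subsequent
#    pipeline.save_pretrained(save_path) writes a plain, merged model.
handle_lora_save(
    use_unet_lora, use_text_lora,
    pipeline,
    output_dir,
    global_step,
    end_train=True
)
```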
@@ -836,6 +885,8 @@ def finetune_unet(batch, train_encoder=False):
                     output_dir,
                     use_unet_lora,
                     use_text_lora,
+                    unet_target_replace_module=unet_lora_modules,
+                    text_target_replace_module=text_encoder_lora_modules,
                     is_checkpoint=True
                 )
 
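The checkpoint call in the training loop now forwards the LoRA target-module lists, so each checkpoint gains a lora/ subfolder alongside the saved pipeline. A small way to sanity-check one of those files; the path and payload layout here are assumptions, as noted in the comments:

```python
import torch

# Hypothetical path; real files land in <output_dir>/lora/<global_step>_unet.pt
# as written by handle_lora_save above.
lora_weights = torch.load("outputs/lora/500_unet.pt", map_location="cpu")

# Assumption: the payload follows the lora_diffusion-style save format that
# utils.lora.save_lora_weight appears to mirror, i.e. a torch.save'd sequence
# of low-rank up/down weight tensors. Print shapes as a quick sanity check.
for i, w in enumerate(lora_weights):
    print(i, tuple(w.shape))
```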