4040 ImageDataset , VideoFolderDataset , CachedDataset
4141from einops import rearrange , repeat
4242
43- from lora_diffusion import (
43+ from utils . lora import (
4444 extract_lora_ups_down ,
4545 inject_trainable_lora ,
4646 inject_trainable_lora_extended ,
@@ -271,7 +271,7 @@ def create_optimizer_params(model_list, lr):
271271 # If this is true, we can train it.
272272 if condition :
273273 for n , p in model .named_parameters ():
274- should_negate = negate_params ( n , negation )
274+ should_negate = 'lora' in n
275275 if should_negate : continue
276276
277277 params = create_optim_params (n , p , lr , extra_params )
@@ -403,11 +403,6 @@ def should_sample(global_step, validation_steps, validation_data):
403403 return (global_step % validation_steps == 0 or global_step == 1 ) \
404404 and validation_data .sample_preview
405405
406- def replace_prompt (prompt , token , wlist ):
407- for w in wlist :
408- if w in prompt : return prompt .replace (w , token )
409- return prompt
410-
411406def save_pipe (
412407 path ,
413408 global_step ,
@@ -418,6 +413,8 @@ def save_pipe(
418413 output_dir ,
419414 use_unet_lora ,
420415 use_text_lora ,
416+ unet_target_replace_module = None ,
417+ text_target_replace_module = None ,
421418 is_checkpoint = False
422419 ):
423420
@@ -440,7 +437,16 @@ def save_pipe(
440437 vae = vae ,
441438 )
442439
443- handle_lora_save (use_unet_lora , use_text_lora , pipeline , end_train = not is_checkpoint )
440+ handle_lora_save (
441+ use_unet_lora , use_text_lora ,
442+ pipeline ,
443+ output_dir ,
444+ global_step ,
445+ unet_target_replace_module ,
446+ text_target_replace_module ,
447+ end_train = not is_checkpoint
448+ )
449+
444450 pipeline .save_pretrained (save_path )
445451
446452 if is_checkpoint :
@@ -451,6 +457,49 @@ def save_pipe(
451457
452458 del pipeline
453459 torch .cuda .empty_cache ()
460+ gc .collect ()
461+
462+
463+ def replace_prompt (prompt , token , wlist ):
464+ for w in wlist :
465+ if w in prompt : return prompt .replace (w , token )
466+ return prompt
467+
468+ def handle_lora_save (
469+ use_unet_lora ,
470+ use_text_lora ,
471+ model ,
472+ save_path ,
473+ checkpoint_step ,
474+ unet_target_replace_module = None ,
475+ text_target_replace_module = None ,
476+ end_train = False
477+ ):
478+ if end_train :
479+ if use_unet_lora :
480+ collapse_lora (model .unet )
481+ monkeypatch_remove_lora (model .unet )
482+
483+ if use_text_lora :
484+ collapse_lora (model .text_encoder )
485+ monkeypatch_remove_lora (model .text_encoder )
486+
487+ if not end_train :
488+ save_path = f"{ save_path } /lora"
489+ os .makedirs (save_path , exist_ok = True )
490+
491+ if use_unet_lora and unet_target_replace_module is not None :
492+ save_lora_weight (
493+ model .unet ,
494+ f"{ save_path } /{ checkpoint_step } _unet.pt" ,
495+ unet_target_replace_module
496+ )
497+ if use_text_lora and text_target_replace_module is not None :
498+ save_lora_weight (
499+ model .text_encoder ,
500+ f"{ save_path } /{ checkpoint_step } _text_encoder.pt" ,
501+ text_target_replace_module
502+ )
454503
455504def main (
456505 pretrained_model_path : str ,
@@ -836,6 +885,8 @@ def finetune_unet(batch, train_encoder=False):
836885 output_dir ,
837886 use_unet_lora ,
838887 use_text_lora ,
888+ unet_target_replace_module = unet_lora_modules ,
889+ text_target_replace_module = text_encoder_lora_modules ,
839890 is_checkpoint = True
840891 )
841892
0 commit comments