 import os
 import random
 import gc
+import copy
 
 from typing import Dict, Optional, Tuple
 from omegaconf import OmegaConf
@@ -44,9 +45,8 @@
     extract_lora_ups_down,
     inject_trainable_lora,
     inject_trainable_lora_extended,
-    safetensors_available,
     save_lora_weight,
-    save_safeloras,
-    collapse_lora,
-    monkeypatch_remove_lora
+    train_patch_pipe,
+    monkeypatch_or_replace_lora,
+    monkeypatch_or_replace_lora_extended
 )
 
 
@@ -180,7 +180,7 @@ def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_t
     except:
         print("Could not enable memory efficient attention for xformers or Torch 2.0.")
 
-def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int = 16):
+def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, lora_path='', r=128):
     injector = (
         inject_trainable_lora if not is_extended
         else
@@ -190,15 +190,44 @@ def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int =
     params = None
     negation = None
 
+    if os.path.exists(lora_path):
+        try:
+            for f in os.listdir(lora_path):
+                if f.endswith('.pt'):
+                    lora_file = os.path.join(lora_path, f)
+
+                    if 'text_encoder' in f and isinstance(model, CLIPTextModel):
+                        monkeypatch_or_replace_lora(
+                            model,
+                            torch.load(lora_file),
+                            target_replace_module=replace_modules,
+                            r=r
+                        )
+                        print("Successfully loaded Text Encoder LoRa.")
+
+                    if 'unet' in f and isinstance(model, UNet3DConditionModel):
+                        monkeypatch_or_replace_lora_extended(
+                            model,
+                            torch.load(lora_file),
+                            target_replace_module=replace_modules,
+                            r=r
+                        )
+                        print("Successfully loaded UNET LoRa.")
+
+        except Exception as e:
+            print(e)
+            print("Could not load LoRAs. Injecting new ones instead...")
+
     if use_lora:
         REPLACE_MODULES = replace_modules
-
-        params, negation = injector(
-            model,
-            target_replace_module=REPLACE_MODULES,
-            r=rank
-        )
-
+        injector_args = {
+            "model": model,
+            "target_replace_module": REPLACE_MODULES,
+            "r": r
+        }
+        if not is_extended: injector_args['dropout_p'] = dropout
+
+        params, negation = injector(**injector_args)
         for _up, _down in extract_lora_ups_down(
             model,
             target_replace_module=REPLACE_MODULES):
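Aside (not part of the commit): a minimal usage sketch of the reworked inject_lora above, mirroring the call sites in main() further down. The folder path is hypothetical; it is only scanned for previously saved *_unet.pt / *_text_encoder.pt files.

# Sketch only: resume (or freshly inject) UNet LoRA layers. Assumes `unet` is a
# loaded UNet3DConditionModel and './outputs/lora' may contain e.g. '2500_unet.pt'.
unet_lora_params, unet_negation = inject_lora(
    use_lora=True,
    model=unet,
    replace_modules=["ResnetBlock2D"],
    is_extended=True,
    lora_path='./outputs/lora',
    r=16
)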
@@ -210,15 +239,42 @@ def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int =
 
     return params, negation
 
-def handle_lora_save(use_unet_lora, use_text_lora, model, end_train=False):
-    if end_train:
-        if use_unet_lora:
-            collapse_lora(model.unet)
-            monkeypatch_remove_lora(model.unet)
-
-        if use_text_lora:
-            collapse_lora(model.text_encoder)
-            monkeypatch_remove_lora(model.text_encoder)
+def save_lora(model, name, condition, replace_modules, step, save_path):
+    if condition and replace_modules is not None:
+        save_path = f"{save_path}/{step}_{name}.pt"
+        save_lora_weight(model, save_path, replace_modules)
+
+def handle_lora_save(
+    use_unet_lora,
+    use_text_lora,
+    model,
+    save_path,
+    checkpoint_step,
+    unet_target_modules,
+    text_encoder_target_modules
+):
+
+    save_path = f"{save_path}/lora"
+    os.makedirs(save_path, exist_ok=True)
+
+    save_lora(
+        model.unet,
+        'unet',
+        use_unet_lora,
+        unet_target_modules,
+        checkpoint_step,
+        save_path,
+    )
+    save_lora(
+        model.text_encoder,
+        'text_encoder',
+        use_text_lora,
+        text_encoder_target_modules,
+        checkpoint_step,
+        save_path
+    )
+
+    train_patch_pipe(model, use_unet_lora, use_text_lora)
 
 def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
     return {
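Aside (not part of the commit): given the save helpers in the hunk above, writing both LoRAs for a pipeline at a given step looks roughly like the sketch below; every name and path is illustrative.

# Sketch: `pipe` is assumed to expose .unet and .text_encoder attributes.
handle_lora_save(
    use_unet_lora=True,
    use_text_lora=True,
    model=pipe,
    save_path='./outputs',                 # files land in ./outputs/lora/
    checkpoint_step=2500,
    unet_target_modules=["ResnetBlock2D"],
    text_encoder_target_modules=["CLIPEncoderLayer"]
)
# Resulting files: ./outputs/lora/2500_unet.pt and ./outputs/lora/2500_text_encoder.pt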
@@ -415,7 +471,7 @@ def save_pipe(
     use_text_lora,
     unet_target_replace_module=None,
     text_target_replace_module=None,
-    is_checkpoint=False
+    is_checkpoint=False,
 ):
 
     if is_checkpoint:
@@ -427,35 +483,39 @@ def save_pipe(
     # Save the dtypes so we can continue training at the same precision.
     u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype
 
-    # We do this to prevent OOM during training when saving a checkpoint.
-    [x.to('cpu') for x in [unet, text_encoder, vae]]
+    # Deep-copy the models instead of referencing them, so the in-progress LoRA training state (if enabled) is left untouched.
+    unet_out = copy.deepcopy(accelerator.unwrap_model(unet, keep_fp32_wrapper=False))
+    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False))
 
     pipeline = TextToVideoSDPipeline.from_pretrained(
         path,
-        unet=unet,
-        text_encoder=text_encoder,
+        unet=unet_out,
+        text_encoder=text_encoder_out,
         vae=vae,
-    )
+    ).to(torch_dtype=torch.float16)
 
     handle_lora_save(
-        use_unet_lora, use_text_lora,
+        use_unet_lora,
+        use_text_lora,
         pipeline,
-        output_dir,
+        output_dir,
         global_step,
-        unet_target_replace_module,
-        text_target_replace_module,
-        end_train=not is_checkpoint
+        unet_target_replace_module,
+        text_target_replace_module
     )
 
     pipeline.save_pretrained(save_path)
 
     if is_checkpoint:
+        unet, text_encoder = accelerator.prepare(unet, text_encoder)
         models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
         [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
 
     logger.info(f"Saved model at {save_path} on step {global_step}")
 
     del pipeline
+    del unet_out
+    del text_encoder_out
     torch.cuda.empty_cache()
     gc.collect()
 
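Aside (not part of the commit): the deep-copy pattern introduced above, in isolation. The exported copy is unwrapped and can be cast to fp16 for saving without disturbing the fp32 weights (and injected LoRA layers) still being trained. Names below are hypothetical; an accelerate.Accelerator instance is assumed.

import copy
import torch

def export_copy(accelerator, model):
    # Unwrap to drop accelerate's wrappers, then deepcopy so the saved pipeline
    # owns its weights; the training model keeps its dtype and LoRA state.
    model_out = copy.deepcopy(accelerator.unwrap_model(model, keep_fp32_wrapper=False))
    return model_out.to(dtype=torch.float16)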
@@ -544,6 +604,7 @@ def main(
     unet_lora_modules: Tuple[str] = ["ResnetBlock2D"],
     text_encoder_lora_modules: Tuple[str] = ["CLIPEncoderLayer"],
     lora_rank: int = 16,
+    lora_path: str = '',
     **kwargs
 ):
 
@@ -590,12 +651,12 @@ def main(
     # Use LoRA if enabled.
     unet_lora_params, unet_negation = inject_lora(
         use_unet_lora, unet, unet_lora_modules, is_extended=True,
-        rank=lora_rank
+        r=lora_rank, lora_path=lora_path
     )
 
     text_encoder_lora_params, text_encoder_negation = inject_lora(
         use_text_lora, text_encoder, text_encoder_lora_modules,
-        rank=lora_rank
+        r=lora_rank, lora_path=lora_path
     )
 
     # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
@@ -604,8 +665,8 @@ def main(
         param_optim(text_encoder, train_text_encoder and not use_text_lora, extra_params=extra_text_encoder_params,
             negation=text_encoder_negation
         ),
-        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 5e-5}),
-        param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 5e-6})
+        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 1e-5}),
+        param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 1e-5})
     ]
 
     params = create_optimizer_params(optim_params, learning_rate)
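Aside (not shown in this diff): create_optimizer_params is expected to translate these entries into standard torch param groups, where a per-group lr (like the 1e-5 values above) overrides the optimizer-wide default. A generic illustration of that mechanism, not the repo's helper:

import torch
import itertools

# Symbols below come from the surrounding script: learning_rate, plus
# unet_lora_params / text_encoder_lora_params as returned by inject_lora,
# assumed here to be iterables of parameter generators (hence chain()).
optimizer = torch.optim.AdamW(
    [
        {"params": itertools.chain(*unet_lora_params), "lr": 1e-5},
        {"params": itertools.chain(*text_encoder_lora_params), "lr": 1e-5},
    ],
    lr=learning_rate,  # default for any group without its own lr
)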
@@ -885,8 +946,8 @@ def finetune_unet(batch, train_encoder=False):
                     output_dir,
                     use_unet_lora,
                     use_text_lora,
-                    unet_target_replace_module=unet_lora_modules,
-                    text_target_replace_module=text_encoder_lora_modules,
+                    unet_lora_modules,
+                    text_encoder_lora_modules,
                     is_checkpoint=True
                 )
 
@@ -959,8 +1020,10 @@ def finetune_unet(batch, train_encoder=False):
             output_dir,
             use_unet_lora,
             use_text_lora,
+            unet_lora_modules,
+            text_encoder_lora_modules,
             is_checkpoint=False
-        )
+        )
     accelerator.end_training()
 
 if __name__ == "__main__":
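Aside (not part of the diff): main() takes its options as keyword arguments and OmegaConf is imported for config loading, so resuming from previously saved LoRAs only requires pointing the new lora_path option at the folder written by handle_lora_save. A hypothetical config fragment, expressed in Python:

from omegaconf import OmegaConf

# Keys mirror main()'s parameters; paths are examples only.
lora_overrides = OmegaConf.create({
    "use_unet_lora": True,
    "use_text_lora": True,
    "unet_lora_modules": ["ResnetBlock2D"],
    "text_encoder_lora_modules": ["CLIPEncoderLayer"],
    "lora_rank": 16,
    "lora_path": "./outputs/previous_run/lora",  # folder holding *_unet.pt / *_text_encoder.pt
})
# Merged into the full training config, main(**config) would then reload those
# weights through inject_lora() before training starts.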