@@ -235,83 +235,6 @@ def set_selector_from_diag(self, diag: torch.Tensor):
             self.lora_up.weight.device
         ).to(self.lora_up.weight.dtype)

-class LoraInjectedConv3d(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: (3, 1, 1),
-        padding: (1, 0, 0),
-        bias: bool = False,
-        r: int = 4,
-        dropout_p: float = 0,
-        scale: float = 1.0,
-    ):
-        super().__init__()
-        if r > min(in_channels, out_channels):
-            raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
-            )
-
-        self.r = r
-        self.kernel_size = kernel_size
-        self.padding = padding
-        self.conv = nn.Conv3d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            padding=padding,
-        )
-
-        self.lora_down = nn.Conv3d(
-            in_channels=in_channels,
-            out_channels=r,
-            kernel_size=kernel_size,
-            bias=False,
-            padding=padding
-        )
-        self.dropout = nn.Dropout(dropout_p)
-        self.lora_up = nn.Conv3d(
-            in_channels=r,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            bias=False,
-            padding=padding
-        )
-        self.selector = nn.Identity()
-        self.scale = scale
-
-        nn.init.normal_(self.lora_down.weight, std=1 / r)
-        nn.init.zeros_(self.lora_up.weight)
-
-    def forward(self, input):
-        return (
-            self.conv(input)
-            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
-            * self.scale
-        )
-
-    def realize_as_lora(self):
-        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
-
-    def set_selector_from_diag(self, diag: torch.Tensor):
-        # diag is a 1D tensor of size (r,)
-        assert diag.shape == (self.r,)
-        self.selector = nn.Conv3d(
-            in_channels=self.r,
-            out_channels=self.r,
-            kernel_size=self.kernel_size,
-            bias=False,
-            padding=self.padding
-        )
-        self.selector.weight.data = torch.diag(diag)
-
-        # same device + dtype as lora_up
-        self.selector.weight.data = self.selector.weight.data.to(
-            self.lora_up.weight.device
-        ).to(self.lora_up.weight.dtype)
-
-
 UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}

 UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
@@ -558,7 +481,7 @@ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
     for _m, _n, _child_module in _find_modules(
         model,
         target_replace_module,
-        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
+        search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
     ):
         loras.append((_child_module.lora_up, _child_module.lora_down))

@@ -577,7 +500,7 @@ def extract_lora_as_tensor(
     for _m, _n, _child_module in _find_modules(
         model,
         target_replace_module,
-        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
+        search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
     ):
         up, down = _child_module.realize_as_lora()
         if as_fp16:
@@ -601,8 +524,8 @@ def save_lora_weight(
     for _up, _down in extract_lora_ups_down(
         model, target_replace_module=target_replace_module
     ):
-        weights.append(_up.weight.to("cpu").to(torch.float16))
-        weights.append(_down.weight.to("cpu").to(torch.float16))
+        weights.append(_up.weight.to("cpu").to(torch.float32))
+        weights.append(_down.weight.to("cpu").to(torch.float32))

     torch.save(weights, path)

@@ -893,7 +816,13 @@ def monkeypatch_or_replace_lora_extended(
     for _module, name, _child_module in _find_modules(
         model,
         target_replace_module,
-        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d, LoraInjectedConv3d],
+        search_class=[
+            nn.Linear,
+            LoraInjectedLinear,
+            nn.Conv2d,
+            LoraInjectedConv2d,
+            LoraInjectedConv3d
+        ],
     ):

         if (_child_module.__class__ == nn.Linear) or (
@@ -951,6 +880,35 @@ def monkeypatch_or_replace_lora_extended(
             if bias is not None:
                 _tmp.conv.bias = bias

+        elif _child_module.__class__ == nn.Conv3d or (
+            _child_module.__class__ == LoraInjectedConv3d
+        ):
+
+            if len(loras[0].shape) != 5:
+                continue
+
+            _source = (
+                _child_module.conv
+                if isinstance(_child_module, LoraInjectedConv3d)
+                else _child_module
+            )
+
+            weight = _source.weight
+            bias = _source.bias
+            _tmp = LoraInjectedConv3d(
+                _source.in_channels,
+                _source.out_channels,
+                bias=_source.bias is not None,
+                kernel_size=_source.kernel_size,
+                padding=_source.padding,
+                r=r.pop(0) if isinstance(r, list) else r,
+            )
+
+            _tmp.conv.weight = weight
+
+            if bias is not None:
+                _tmp.conv.bias = bias
+
         # switch the module
         _module._modules[name] = _tmp

@@ -1000,20 +958,35 @@ def monkeypatch_remove_lora(model):
             _source = _child_module.conv
             weight, bias = _source.weight, _source.bias

-            _tmp = nn.Conv2d(
-                in_channels=_source.in_channels,
-                out_channels=_source.out_channels,
+            if isinstance(_source, nn.Conv2d):
+                _tmp = nn.Conv2d(
+                    in_channels=_source.in_channels,
+                    out_channels=_source.out_channels,
+                    kernel_size=_source.kernel_size,
+                    stride=_source.stride,
+                    padding=_source.padding,
+                    dilation=_source.dilation,
+                    groups=_source.groups,
+                    bias=bias is not None,
+                )
+
+                _tmp.weight = weight
+                if bias is not None:
+                    _tmp.bias = bias
+
+            if isinstance(_source, nn.Conv3d):
+                _tmp = nn.Conv3d(
+                    _source.in_channels,
+                    _source.out_channels,
+                    bias=_source.bias is not None,
                 kernel_size=_source.kernel_size,
-                stride=_source.stride,
                 padding=_source.padding,
-                dilation=_source.dilation,
-                groups=_source.groups,
-                bias=bias is not None,
             )

-            _tmp.weight = weight
+                _tmp.conv.weight = weight
+
             if bias is not None:
-                _tmp.bias = bias
+                _tmp.conv.bias = bias

         _module._modules[name] = _tmp

@@ -1243,7 +1216,6 @@ def save_all(

         # save text encoder
         if save_lora:
-
             save_lora_weight(
                 unet, save_path, target_replace_module=target_replace_module_unet
             )
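
For reference, a minimal sketch (not part of the commit) exercising the LoraInjectedConv3d wrapper shown in the first hunk. It assumes the class definition is available locally; the `lora` module name in the import is an assumption about where this file lives, not something the diff confirms.

import torch

# Hypothetical import path; adjust to wherever this lora.py sits in the repo.
from lora import LoraInjectedConv3d

# Wrap a temporal (3, 1, 1)-kernel conv, matching the annotations in __init__.
layer = LoraInjectedConv3d(
    in_channels=8,
    out_channels=8,
    kernel_size=(3, 1, 1),
    padding=(1, 0, 0),
    r=4,
)

# Video-style input: (batch, channels, frames, height, width).
x = torch.randn(1, 8, 16, 32, 32)
out = layer(x)
print(out.shape)  # torch.Size([1, 8, 16, 32, 32])

# lora_up is zero-initialized, so the LoRA branch contributes nothing yet
# and the wrapper reproduces the wrapped conv's output exactly.
assert torch.allclose(out, layer.conv(x))

# realize_as_lora() returns the (scaled) up and down factors used when saving.
up, down = layer.realize_as_lora()
print(up.shape, down.shape)  # (8, 4, 3, 1, 1) and (4, 8, 3, 1, 1)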