This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 8ed8714

Update lora util functions

1 parent: aa5991e

File tree: 1 file changed, +64 -92 lines

utils/lora.py (64 additions, 92 deletions)
@@ -235,83 +235,6 @@ def set_selector_from_diag(self, diag: torch.Tensor):
             self.lora_up.weight.device
         ).to(self.lora_up.weight.dtype)
 
-class LoraInjectedConv3d(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: (3, 1, 1),
-        padding: (1, 0, 0),
-        bias: bool = False,
-        r: int = 4,
-        dropout_p: float = 0,
-        scale: float = 1.0,
-    ):
-        super().__init__()
-        if r > min(in_channels, out_channels):
-            raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
-            )
-
-        self.r = r
-        self.kernel_size = kernel_size
-        self.padding = padding
-        self.conv = nn.Conv3d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            padding=padding,
-        )
-
-        self.lora_down = nn.Conv3d(
-            in_channels=in_channels,
-            out_channels=r,
-            kernel_size=kernel_size,
-            bias=False,
-            padding=padding
-        )
-        self.dropout = nn.Dropout(dropout_p)
-        self.lora_up = nn.Conv3d(
-            in_channels=r,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            bias=False,
-            padding=padding
-        )
-        self.selector = nn.Identity()
-        self.scale = scale
-
-        nn.init.normal_(self.lora_down.weight, std=1 / r)
-        nn.init.zeros_(self.lora_up.weight)
-
-    def forward(self, input):
-        return (
-            self.conv(input)
-            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
-            * self.scale
-        )
-
-    def realize_as_lora(self):
-        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
-
-    def set_selector_from_diag(self, diag: torch.Tensor):
-        # diag is a 1D tensor of size (r,)
-        assert diag.shape == (self.r,)
-        self.selector = nn.Conv3d(
-            in_channels=self.r,
-            out_channels=self.r,
-            kernel_size=self.kernel_size,
-            bias=False,
-            padding=self.padding
-        )
-        self.selector.weight.data = torch.diag(diag)
-
-        # same device + dtype as lora_up
-        self.selector.weight.data = self.selector.weight.data.to(
-            self.lora_up.weight.device
-        ).to(self.lora_up.weight.dtype)
-
-
 UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
 
 UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
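
The class removed above is the 3-D counterpart of the existing LoraInjected wrappers: the base nn.Conv3d output plus a scaled low-rank path lora_up(selector(lora_down(x))), with lora_up zero-initialized so a freshly injected layer reproduces the base convolution exactly. A minimal standalone sketch of that pattern and its zero-init property; shapes and rank are chosen arbitrarily for illustration, not taken from the repo:

import torch
import torch.nn as nn

# Minimal rendering of the removed forward(): conv(x) + scale * up(down(x)),
# with the up projection zero-initialized (dropout and selector default to identity).
conv = nn.Conv3d(8, 8, kernel_size=(3, 1, 1), padding=(1, 0, 0))
lora_down = nn.Conv3d(8, 4, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
lora_up = nn.Conv3d(4, 8, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
nn.init.normal_(lora_down.weight, std=1 / 4)
nn.init.zeros_(lora_up.weight)

x = torch.randn(1, 8, 4, 16, 16)              # (batch, channels, frames, height, width)
out = conv(x) + lora_up(lora_down(x)) * 1.0   # scale = 1.0

# Because lora_up starts at zero, an injected layer initially behaves like the base conv.
assert torch.allclose(out, conv(x))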
@@ -558,7 +481,7 @@ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
     for _m, _n, _child_module in _find_modules(
         model,
         target_replace_module,
-        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
+        search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
     ):
         loras.append((_child_module.lora_up, _child_module.lora_down))
 
@@ -577,7 +500,7 @@ def extract_lora_as_tensor(
     for _m, _n, _child_module in _find_modules(
         model,
         target_replace_module,
-        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
+        search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
     ):
         up, down = _child_module.realize_as_lora()
         if as_fp16:
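
The practical effect of widening search_class in these two hunks: module discovery is class-based, so without LoraInjectedConv3d in the list any injected 3-D conv layers were invisible to extraction and silently missing from saved LoRA weights. The implementation of _find_modules is not shown in this diff; the sketch below only illustrates the kind of isinstance filter its search_class argument implies, with hypothetical names:

import torch.nn as nn

def iter_modules_of(model: nn.Module, search_class: list):
    # Yield (name, module) pairs whose type matches one of the requested classes,
    # an isinstance-style filter in the spirit of _find_modules' search_class argument.
    for name, module in model.named_modules():
        if isinstance(module, tuple(search_class)):
            yield name, module

# Only layer types listed in search_class are reported.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv3d(8, 8, (3, 1, 1)))
print([n for n, _ in iter_modules_of(model, [nn.Conv2d])])             # ['0']
print([n for n, _ in iter_modules_of(model, [nn.Conv2d, nn.Conv3d])])  # ['0', '1']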
@@ -601,8 +524,8 @@ def save_lora_weight(
     for _up, _down in extract_lora_ups_down(
         model, target_replace_module=target_replace_module
    ):
-        weights.append(_up.weight.to("cpu").to(torch.float16))
-        weights.append(_down.weight.to("cpu").to(torch.float16))
+        weights.append(_up.weight.to("cpu").to(torch.float32))
+        weights.append(_down.weight.to("cpu").to(torch.float32))
 
     torch.save(weights, path)
 
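save_lora_weight now writes the up/down tensors at float32 instead of float16, so checkpoints keep full precision at the cost of a larger file. Callers that prefer the smaller footprint can still down-cast after loading. A small sketch of that round trip; the file names are illustrative, not from the repo:

import torch

# save_lora_weight() writes a flat list of alternating up/down weight tensors,
# now stored on the CPU in float32.
weights = torch.load("unet_lora.pt", map_location="cpu")  # hypothetical path

print(len(weights), weights[0].dtype)  # e.g. N tensors, torch.float32

# Optional: cast back to half precision if file size matters more than the
# extra precision kept by this commit.
weights_fp16 = [w.to(torch.float16) for w in weights]
torch.save(weights_fp16, "unet_lora_fp16.pt")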
@@ -893,7 +816,13 @@ def monkeypatch_or_replace_lora_extended(
     for _module, name, _child_module in _find_modules(
         model,
         target_replace_module,
-        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d, LoraInjectedConv3d],
+        search_class=[
+            nn.Linear,
+            LoraInjectedLinear,
+            nn.Conv2d,
+            LoraInjectedConv2d,
+            LoraInjectedConv3d
+        ],
     ):
 
         if (_child_module.__class__ == nn.Linear) or (
@@ -951,6 +880,35 @@ def monkeypatch_or_replace_lora_extended(
             if bias is not None:
                 _tmp.conv.bias = bias
 
+        elif _child_module.__class__ == nn.Conv3d or(
+            _child_module.__class__ == LoraInjectedConv3d
+        ):
+
+            if len(loras[0].shape) != 5:
+                continue
+
+            _source = (
+                _child_module.conv
+                if isinstance(_child_module, LoraInjectedConv3d)
+                else _child_module
+            )
+
+            weight = _source.weight
+            bias = _source.bias
+            _tmp = LoraInjectedConv3d(
+                _source.in_channels,
+                _source.out_channels,
+                bias=_source.bias is not None,
+                kernel_size=_source.kernel_size,
+                padding=_source.padding,
+                r=r.pop(0) if isinstance(r, list) else r,
+            )
+
+            _tmp.conv.weight = weight
+
+            if bias is not None:
+                _tmp.conv.bias = bias
+
         # switch the module
         _module._modules[name] = _tmp
 
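The new Conv3d branch mirrors the Conv2d handling above it: unwrap an already-injected module through .conv, skip LoRA weight lists whose leading tensor is not 5-D (a Conv3d kernel is out_channels x in_channels x D x H x W), rebuild the layer with the source's kernel_size and padding, and copy the original weight and bias across. A standalone sketch of that copy-and-verify step using a plain nn.Conv3d pair, since LoraInjectedConv3d's new import location is not visible in this diff:

import torch
import torch.nn as nn

source = nn.Conv3d(8, 16, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=True)

# Build a replacement with the same geometry, as the branch does via
# LoraInjectedConv3d(..., kernel_size=_source.kernel_size, padding=_source.padding, ...).
replacement = nn.Conv3d(
    source.in_channels,
    source.out_channels,
    kernel_size=source.kernel_size,
    padding=source.padding,
    bias=source.bias is not None,
)

# Reuse the original parameters instead of re-initializing them.
replacement.weight = source.weight
replacement.bias = source.bias

x = torch.randn(1, 8, 4, 16, 16)
assert torch.allclose(replacement(x), source(x))  # behaviour preserved after the swap
assert source.weight.dim() == 5                   # the len(loras[0].shape) != 5 guard keys on this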
@@ -1000,20 +958,35 @@ def monkeypatch_remove_lora(model):
             _source = _child_module.conv
             weight, bias = _source.weight, _source.bias
 
-            _tmp = nn.Conv2d(
-                in_channels=_source.in_channels,
-                out_channels=_source.out_channels,
+            if isinstance(_source, nn.Conv2d):
+                _tmp = nn.Conv2d(
+                    in_channels=_source.in_channels,
+                    out_channels=_source.out_channels,
+                    kernel_size=_source.kernel_size,
+                    stride=_source.stride,
+                    padding=_source.padding,
+                    dilation=_source.dilation,
+                    groups=_source.groups,
+                    bias=bias is not None,
+                )
+
+                _tmp.weight = weight
+                if bias is not None:
+                    _tmp.bias = bias
+
+            if isinstance(_source, nn.Conv3d):
+                _tmp = nn.Conv3d(
+                    _source.in_channels,
+                    _source.out_channels,
+                    bias=_source.bias is not None,
                 kernel_size=_source.kernel_size,
-                stride=_source.stride,
                 padding=_source.padding,
-                dilation=_source.dilation,
-                groups=_source.groups,
-                bias=bias is not None,
             )
 
-            _tmp.weight = weight
+            _tmp.conv.weight = weight
+
             if bias is not None:
-                _tmp.bias = bias
+                _tmp.conv.bias = bias
 
             _module._modules[name] = _tmp
 
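Both the injection and the removal paths swap layers by assigning into the parent's _modules registry (_module._modules[name] = _tmp), which leaves the surrounding architecture untouched and exchanges only the named child. A tiny illustration of that mechanism on a throwaway model; the module and names here are arbitrary:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv3d(4, 4, (3, 1, 1), padding=(1, 0, 0)), nn.ReLU())

parent, name = model, "0"                  # parent module and the child's registered name
replacement = nn.Conv3d(4, 4, (3, 1, 1), padding=(1, 0, 0))

# Re-bind the child in place, exactly like `_module._modules[name] = _tmp` in the hunks above.
parent._modules[name] = replacement

assert model[0] is replacement
out = model(torch.randn(1, 4, 2, 8, 8))    # forward still works with the swapped layer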
@@ -1243,7 +1216,6 @@ def save_all(
 
     # save text encoder
     if save_lora:
-
         save_lora_weight(
            unet, save_path, target_replace_module=target_replace_module_unet
        )
