Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 6b77294

Browse files
committed
Add train patch pipe
1 parent 3bca4aa commit 6b77294

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

utils/lora.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def save_lora_weight(
519519
model,
520520
path="./lora.pt",
521521
target_replace_module=DEFAULT_TARGET_REPLACE,
522-
):
522+
):
523523
weights = []
524524
for _up, _down in extract_lora_ups_down(
525525
model, target_replace_module=target_replace_module
@@ -990,7 +990,6 @@ def monkeypatch_remove_lora(model):
990990
_module._modules[name] = _tmp
991991

992992

993-
994993
def monkeypatch_add_lora(
995994
model,
996995
loras,
@@ -1166,6 +1165,18 @@ def patch_pipe(
11661165
return tok_dict
11671166

11681167

1168+
def train_patch_pipe(pipe, patch_unet, patch_text):
1169+
if patch_unet:
1170+
print("LoRA : Patching Unet")
1171+
collapse_lora(pipe.unet)
1172+
monkeypatch_remove_lora(pipe.unet)
1173+
1174+
if patch_text:
1175+
print("LoRA : Patching text encoder")
1176+
1177+
collapse_lora(pipe.text_encoder)
1178+
monkeypatch_remove_lora(pipe.text_encoder)
1179+
11691180
@torch.no_grad()
11701181
def inspect_lora(model):
11711182
moved = {}

0 commit comments

Comments
 (0)