Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 07eb5bb

Browse files
Fix lora saving
1 parent 7640ce3 commit 07eb5bb

File tree

1 file changed

+9
-8
lines changed

train.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,15 @@ def inject_lora(use_lora, model, replace_modules, is_extended=False, rank: int =
210210

211211
return params, negation
212212

213-
# NOTE(review): the lines below are the PRE-commit implementation (the "-"
# side of the diff), shown here with diff line-number tokens interleaved.
# It collapsed and stripped the LoRA layers on every call — the commit
# replaces it with a version gated behind an `end_train` flag (see the "+"
# side of this hunk).
def handle_lora_save(use_unet_lora, use_text_lora, model):
214-
if use_unet_lora:
215-
collapse_lora(model.unet)
216-
monkeypatch_remove_lora(model.unet)
217-
218-
if use_text_lora:
219-
collapse_lora(model.text_encoder)
220-
monkeypatch_remove_lora(model.text_encoder)
def handle_lora_save(use_unet_lora, use_text_lora, model, end_train=False):
    """Collapse and strip LoRA adapters from ``model`` at the end of training.

    No-op unless ``end_train`` is True, so intermediate checkpoint saves keep
    the LoRA layers attached (presumably the "Fix lora saving" this commit
    addresses — TODO confirm against the caller in the training loop).

    Args:
        use_unet_lora: collapse/remove LoRA on ``model.unet`` when True.
        use_text_lora: collapse/remove LoRA on ``model.text_encoder`` when True.
        model: object exposing ``unet`` and ``text_encoder`` attributes.
        end_train: only perform the (irreversible) collapse when True.
    """
    # Guard clause: during ordinary checkpointing there is nothing to do.
    if not end_train:
        return

    if use_unet_lora:
        # Fold LoRA weights into the base weights, then drop the LoRA hooks.
        collapse_lora(model.unet)
        monkeypatch_remove_lora(model.unet)

    if use_text_lora:
        collapse_lora(model.text_encoder)
        monkeypatch_remove_lora(model.text_encoder)
221222

222223
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
223224
return {

0 commit comments

Comments
 (0)