Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 5b33fc8

Browse files
committed
Add LoRA inference.
1 parent 8dbd2ff commit 5b33fc8

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

inference.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from pathlib import Path
55
from uuid import uuid4
6-
6+
from utils.lora import inject_inferable_lora
77
import torch
88
from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
99
from einops import rearrange
@@ -111,10 +111,12 @@ def inference(
111111
device="cuda",
112112
xformers=False,
113113
sdp=False,
114+
lora_path='',
115+
lora_rank=64
114116
):
115117
with torch.autocast(device, dtype=torch.half):
116118
pipeline = initialize_pipeline(model, device, xformers, sdp)
117-
119+
inject_inferable_lora(pipeline, lora_path, lora_rank)
118120
prompt = [prompt] * batch_size
119121
negative_prompt = ([negative_prompt] * batch_size) if negative_prompt is not None else None
120122

@@ -168,6 +170,8 @@ def inference(
168170
parser.add_argument("-d", "--device", type=str, default="cuda")
169171
parser.add_argument("-x", "--xformers", action="store_true")
170172
parser.add_argument("-S", "--sdp", action="store_true")
173+
parser.add_argument("-lP", "--lora_path", type=str, default="")
174+
parser.add_argument("-lR", "--lora_rank", type=int, default=64)
171175
parser.add_argument("-rw", "--remove-watermark", action="store_true")
172176
args = vars(parser.parse_args())
173177

utils/lora.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,52 @@ def inject_trainable_lora_extended(
474474
return require_grad_params, names
475475

476476

477+
def inject_inferable_lora(
    model,
    lora_path='',
    unet_replace_modules=["UNet3DConditionModel"],
    text_encoder_replace_modules=["CLIPTextModel"],
    is_extended=False,
    r=16
):
    """Load LoRA weight files from *lora_path* into an already-built pipeline.

    Scans *lora_path* for ``.pt`` files and injects each into the matching
    sub-module of *model* (a diffusers text-to-video pipeline): files whose
    name contains ``text_encoder`` go into ``model.text_encoder``, files whose
    name contains ``unet`` go into ``model.unet``. Expected file names are
    ``text_encoder.pt`` and ``unet.pt``.

    Args:
        model: Pipeline object exposing ``.unet`` and ``.text_encoder``
            attributes (e.g. ``TextToVideoSDPipeline``).
        lora_path: Directory containing the ``.pt`` LoRA files. If empty or
            nonexistent, the function is a silent no-op.
        unet_replace_modules: Module class names targeted inside the UNet.
        text_encoder_replace_modules: Module class names targeted inside the
            text encoder.
        is_extended: Currently unused — kept for interface compatibility.
            NOTE(review): possibly intended to select the extended monkeypatch
            for the UNet; confirm against training-side usage.
        r: LoRA rank used when rebuilding the injected layers.

    Returns:
        None. Errors during injection are caught and printed (best-effort),
        matching the original behavior.
    """
    if not os.path.exists(lora_path):
        return

    # Import lazily and only when there is actually something to load, so the
    # no-op path does not require these heavy dependencies.
    from transformers.models.clip import CLIPTextModel
    from diffusers import UNet3DConditionModel

    # Bug fix: the original tested isinstance(model, ...) against the whole
    # pipeline, which is never a CLIPTextModel/UNet3DConditionModel, so no
    # LoRA was ever injected. Check the pipeline's sub-modules instead.
    def is_text_model(f):
        return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)

    def is_unet(f):
        return 'unet' in f and isinstance(model.unet, UNet3DConditionModel)

    try:
        for f in os.listdir(lora_path):
            if not f.endswith('.pt'):
                continue
            lora_file = os.path.join(lora_path, f)

            if is_text_model(f):
                # Bug fix: inject into the text encoder with the text-encoder
                # target modules (the originals were swapped with the UNet's),
                # and keep scanning so a unet.pt in the same folder also loads.
                monkeypatch_or_replace_lora(
                    model.text_encoder,
                    torch.load(lora_file),
                    target_replace_module=text_encoder_replace_modules,
                    r=r
                )
                print("Successfully loaded Text Encoder LoRa.")
                continue

            if is_unet(f):
                # Bug fix: inject into the UNet with the UNet target modules.
                monkeypatch_or_replace_lora_extended(
                    model.unet,
                    torch.load(lora_file),
                    target_replace_module=unet_replace_modules,
                    r=r
                )
                print("Successfully loaded UNET LoRa.")
                continue

            print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")

    except Exception as e:
        print(e)
        print("Couldn't inject LoRA's due to an error.")
477523
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
478524

479525
loras = []

0 commit comments

Comments
 (0)