Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 87beeaa

Browse files
Merge pull request #51 from ExponentialML/feat/lora-infer
Add LoRA Inference.
2 parents 8ce2e52 + 09ed1ae commit 87beeaa

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

inference.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from pathlib import Path
55
from uuid import uuid4
6-
6+
from utils.lora import inject_inferable_lora
77
import torch
88
from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
99
from einops import rearrange
@@ -111,10 +111,12 @@ def inference(
111111
device="cuda",
112112
xformers=False,
113113
sdp=False,
114+
lora_path='',
115+
lora_rank=64
114116
):
115117
with torch.autocast(device, dtype=torch.half):
116118
pipeline = initialize_pipeline(model, device, xformers, sdp)
117-
119+
inject_inferable_lora(pipeline, lora_path, r=lora_rank)
118120
prompt = [prompt] * batch_size
119121
negative_prompt = ([negative_prompt] * batch_size) if negative_prompt is not None else None
120122

@@ -168,6 +170,8 @@ def inference(
168170
parser.add_argument("-d", "--device", type=str, default="cuda")
169171
parser.add_argument("-x", "--xformers", action="store_true")
170172
parser.add_argument("-S", "--sdp", action="store_true")
173+
parser.add_argument("-lP", "--lora_path", type=str, default="")
174+
parser.add_argument("-lR", "--lora_rank", type=int, default=64)
171175
parser.add_argument("-rw", "--remove-watermark", action="store_true")
172176
args = vars(parser.parse_args())
173177

utils/lora.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import math
33
from itertools import groupby
4+
import os
45
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
56

67
import numpy as np
@@ -474,6 +475,52 @@ def inject_trainable_lora_extended(
474475
return require_grad_params, names
475476

476477

478+
def inject_inferable_lora(
479+
model,
480+
lora_path='',
481+
unet_replace_modules=["UNet3DConditionModel"],
482+
text_encoder_replace_modules=["CLIPEncoderLayer"],
483+
is_extended=False,
484+
r=16
485+
):
486+
from transformers.models.clip import CLIPTextModel
487+
from diffusers import UNet3DConditionModel
488+
489+
def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
490+
def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
491+
492+
if os.path.exists(lora_path):
493+
try:
494+
for f in os.listdir(lora_path):
495+
if f.endswith('.pt'):
496+
lora_file = os.path.join(lora_path, f)
497+
498+
if is_text_model(f):
499+
monkeypatch_or_replace_lora(
500+
model.text_encoder,
501+
torch.load(lora_file),
502+
target_replace_module=text_encoder_replace_modules,
503+
r=r
504+
)
505+
print("Successfully loaded Text Encoder LoRa.")
506+
continue
507+
508+
if is_unet(f):
509+
monkeypatch_or_replace_lora_extended(
510+
model.unet,
511+
torch.load(lora_file),
512+
target_replace_module=unet_replace_modules,
513+
r=r
514+
)
515+
print("Successfully loaded UNET LoRa.")
516+
continue
517+
518+
print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")
519+
520+
except Exception as e:
521+
print(e)
522+
print("Couldn't inject LoRA's due to an error.")
523+
477524
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
478525

479526
loras = []

0 commit comments

Comments
 (0)