Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit b5797d4

Browse files
committed
Update lora args, imports, and small fixes
1 parent 5b33fc8 commit b5797d4

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def inference(
116116
):
117117
with torch.autocast(device, dtype=torch.half):
118118
pipeline = initialize_pipeline(model, device, xformers, sdp)
119-
inject_inferable_lora(pipeline, lora_path, lora_rank)
119+
inject_inferable_lora(pipeline, lora_path, r=lora_rank)
120120
prompt = [prompt] * batch_size
121121
negative_prompt = ([negative_prompt] * batch_size) if negative_prompt is not None else None
122122

utils/lora.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import math
33
from itertools import groupby
4+
import os
45
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
56

67
import numpy as np
@@ -485,34 +486,34 @@ def inject_inferable_lora(
485486
from transformers.models.clip import CLIPTextModel
486487
from diffusers import UNet3DConditionModel
487488

488-
def is_text_model(f): return 'text_encoder' in f and isinstance(model, CLIPTextModel)
489-
def is_unet(f): return 'unet' in f and isinstance(model, UNet3DConditionModel)
489+
def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
490+
def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
490491

491492
if os.path.exists(lora_path):
492493
try:
493494
for f in os.listdir(lora_path):
494495
if f.endswith('.pt'):
495496
lora_file = os.path.join(lora_path, f)
496-
497+
497498
if is_text_model(f):
498499
monkeypatch_or_replace_lora(
499-
model,
500+
model.text_encoder,
500501
torch.load(lora_file),
501502
target_replace_module=unet_replace_modules,
502503
r=r
503504
)
504505
print("Successfully loaded Text Encoder LoRa.")
505-
return
506-
506+
continue
507+
507508
if is_unet(f):
508509
monkeypatch_or_replace_lora_extended(
509-
model,
510+
model.unet,
510511
torch.load(lora_file),
511512
target_replace_module=text_encoder_replace_modules,
512513
r=r
513514
)
514515
print("Successfully loaded UNET LoRa.")
515-
return
516+
continue
516517

517518
print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")
518519

0 commit comments

Comments
 (0)