import json
import math
from itertools import groupby
+import os
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

import numpy as np
@@ -474,6 +475,52 @@ def inject_trainable_lora_extended( |
    return require_grad_params, names


+def inject_inferable_lora(
+    model,
+    lora_path='',
+    unet_replace_modules=["UNet3DConditionModel"],
+    text_encoder_replace_modules=["CLIPEncoderLayer"],
+    is_extended=False,
+    r=16
+    ):
+    from transformers.models.clip import CLIPTextModel
+    from diffusers import UNet3DConditionModel
+
+    # Route each weight file by its name: text_encoder*.pt -> text encoder, unet*.pt -> UNet.
+    def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
+    def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
+
+    if os.path.exists(lora_path):
+        try:
+            for f in os.listdir(lora_path):
+                if f.endswith('.pt'):
+                    lora_file = os.path.join(lora_path, f)
+
+                    if is_text_model(f):
+                        monkeypatch_or_replace_lora(
+                            model.text_encoder,
+                            torch.load(lora_file),
+                            target_replace_module=text_encoder_replace_modules,
+                            r=r
+                        )
+                        print("Successfully loaded Text Encoder LoRA.")
+                        continue
+
+                    if is_unet(f):
+                        monkeypatch_or_replace_lora_extended(
+                            model.unet,
+                            torch.load(lora_file),
+                            target_replace_module=unet_replace_modules,
+                            r=r
+                        )
+                        print("Successfully loaded UNet LoRA.")
+                        continue
+
+                    print("Found a .pt file, but it doesn't match the expected name format (unet.pt, text_encoder.pt).")
+
+        except Exception as e:
+            print(e)
+            print("Couldn't inject LoRAs due to an error.")
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):

    loras = []
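
Usage sketch (not part of this commit; a minimal example under stated assumptions): `model` is expected to be a loaded diffusers pipeline that exposes `.unet` (a class named `UNet3DConditionModel`) and `.text_encoder` (a `CLIPTextModel`). The checkpoint path and LoRA directory below are hypothetical placeholders.

from diffusers import DiffusionPipeline

# Hypothetical checkpoint path; substitute a real text-to-video checkpoint.
pipe = DiffusionPipeline.from_pretrained("path/to/text-to-video-checkpoint")

# The directory is expected to contain unet.pt and/or text_encoder.pt;
# any other .pt file name is reported and skipped. r must match the rank
# the LoRA weights were trained with, since it sizes the patched layers.
inject_inferable_lora(
    pipe,
    lora_path="./lora",
    r=16,
)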