This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 5ba1085

fix encoder conversion by using the base model

Take an existing CLIP checkpoint in ModelScope format and overwrite only the changed (converted) layers.
1 parent 9b14bbe commit 5ba1085
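
In short: instead of emitting a full CLIP checkpoint from scratch (which didn't work, per the old TODO removed below), the script now loads a known-good CLIP checkpoint that is already in ModelScope format, pointed to by the new --clip_base_path argument, and overwrites only the layers the Diffusers-to-ModelScope conversion produces. A minimal sketch of that patching step, assuming plain torch-format checkpoints (the function name and paths here are illustrative, not part of the script):

import torch

def patch_clip_base(converted_layers, clip_base_path, out_path):
    # Load a known-good CLIP checkpoint already in ModelScope format.
    base = torch.load(clip_base_path, map_location="cpu")
    # Overwrite only the keys produced by the Diffusers -> ModelScope
    # conversion; every other tensor keeps its value from the base model.
    for key, tensor in converted_layers.items():
        if key in base:
            base[key] = tensor
    torch.save(base, out_path)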

File tree

1 file changed: +55 -38 lines changed


utils/convert_diffusers_to_original_ms_text_to_video.py

Lines changed: 55 additions & 38 deletions
@@ -6,7 +6,7 @@
 import os.path as osp
 import re
 
-import torch
+import torch, gc
 from safetensors.torch import load_file, save_file
 
 # =================#
@@ -369,6 +369,7 @@ def convert_text_enc_state_dict(text_enc_dict):
     parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
     parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
+    parser.add_argument("--clip_base_path", default=None, type=str, help="Path to the source original ModelScope-format (!) CLIP model.")
     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
     parser.add_argument(
         "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
@@ -380,7 +381,9 @@ def convert_text_enc_state_dict(text_enc_dict):
 
     assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
 
-    assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
+    assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint output path!"
+
+    assert args.clip_base_path is not None, "Must provide an existing original ModelScope format (!) CLIP checkpoint path!"
 
     # Path for safetensors
     unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
@@ -394,31 +397,43 @@ def convert_text_enc_state_dict(text_enc_dict):
         unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
         unet_state_dict = torch.load(unet_path, map_location="cpu")
 
-    # if osp.exists(vae_path):
-    #     vae_state_dict = load_file(vae_path, device="cpu")
-    # else:
-    #     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
-    #     vae_state_dict = torch.load(vae_path, map_location="cpu")
+    # Convert the UNet model
+    unet_state_dict = convert_unet_state_dict(unet_state_dict)
+    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+    print ('Saving UNET')
+    unet_state_dict = {**unet_state_dict}
+
+    if args.half:
+        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}
+
+    if args.use_safetensors:
+        save_file(unet_state_dict, args.checkpoint_path)
+    else:
+        #state_dict = {"state_dict": state_dict}
+        torch.save(unet_state_dict, args.checkpoint_path)
+
+    del unet_state_dict
+    gc.collect()
+    print ('UNET Saved')
+
+    print ('Converting CLIP')
 
+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     if osp.exists(text_enc_path):
         text_enc_dict = load_file(text_enc_path, device="cpu")
     else:
         text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
         text_enc_dict = torch.load(text_enc_path, map_location="cpu")
 
-    # Convert the UNet model
-    unet_state_dict = convert_unet_state_dict(unet_state_dict)
-    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+    with open('l.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)
 
-    # Convert the VAE model
-    # vae_state_dict = convert_vae_state_dict(vae_state_dict)
-    # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
-
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
 
+    print(is_v20_model)
     if is_v20_model:
-
         # MODELSCOPE always uses the 2.X encoder, btw --kabachuha
 
         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
@@ -429,33 +444,35 @@ def convert_text_enc_state_dict(text_enc_dict):
         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
         #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
 
-    # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
-    # Save CLIP and the Diffusion model to their own files
+    clip_base_path = args.clip_base_path
 
-    #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-    print ('Saving UNET')
-    state_dict = {**unet_state_dict}
+    # HACK: grab a preexisting openclip model and change only the converted layers
+    #if osp.exists(clip_base_path):
+    #    text_enc_dict_base = load_file(clip_base_path, device="cpu")
+    #else:
+    text_enc_dict_base = torch.load(clip_base_path, map_location="cpu")
 
-    if args.half:
-        state_dict = {k: v.half() for k, v in state_dict.items()}
+    print ('Changing the changed CLIP layers')
+    for k, v in text_enc_dict.items():
+        if k in text_enc_dict_base:
+            text_enc_dict_base[k] = v
 
-    if args.use_safetensors:
-        save_file(state_dict, args.checkpoint_path)
-    else:
-        #state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, args.checkpoint_path)
+    text_enc_dict = text_enc_dict_base
 
-    # TODO: CLIP conversion doesn't work atm
-    # print ('Saving CLIP')
-    # state_dict = {**text_enc_dict}
+    print ('Saving CLIP')
+    text_enc_dict = {**text_enc_dict}
 
-    # if args.half:
-    #     state_dict = {k: v.half() for k, v in state_dict.items()}
+    #with open('l1.txt', 'w') as deb:
+    #    text = '\n'.join(text_enc_dict.keys())
+    #    deb.write(text)
 
-    # if args.use_safetensors:
-    #     save_file(state_dict, args.checkpoint_path)
-    # else:
-    #     #state_dict = {"state_dict": state_dict}
-    #     torch.save(state_dict, args.clip_checkpoint_path)
+    if args.half:
+        text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}
+
+    if args.use_safetensors:
+        save_file(text_enc_dict, args.checkpoint_path)
+    else:
+        #state_dict = {"state_dict": text_enc_dict}
+        torch.save(text_enc_dict, args.clip_checkpoint_path)
 
     print('Operation successfull')
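
The new gc import supports the reordering above: the converted UNet is written and released before the base CLIP checkpoint is loaded, so the two large state dicts never sit in memory at once. A rough sketch of that save-then-free pattern, with stand-in tensors and an illustrative output path:

import gc
import torch

unet_state_dict = {"weight": torch.zeros(4, 4)}  # stand-in for the converted UNet
torch.save(unet_state_dict, "unet.ckpt")         # illustrative output path

# Drop the only reference and force a collection before the next large
# torch.load, so the UNet and CLIP state dicts never coexist in memory.
del unet_state_dict
gc.collect()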

0 commit comments
