93 changes: 55 additions & 38 deletions utils/convert_diffusers_to_original_ms_text_to_video.py
@@ -6,7 +6,7 @@
import os.path as osp
import re

import torch
import torch, gc
from safetensors.torch import load_file, save_file

# =================#
@@ -369,6 +369,7 @@ def convert_text_enc_state_dict(text_enc_dict):
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
parser.add_argument("--clip_base_path", default=None, type=str, help="Path to the source original ModelScope-format (!) CLIP model.")
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
parser.add_argument(
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
@@ -380,7 +381,9 @@ def convert_text_enc_state_dict(text_enc_dict):

assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint output path!"

assert args.clip_base_path is not None, "Must provide the path to an existing CLIP checkpoint in the original ModelScope format!"

# Path for safetensors
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
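Taken together, the arguments above mean a conversion run with this patch looks roughly like the following; the paths are placeholders for illustration, not files that ship with the PR:

python utils/convert_diffusers_to_original_ms_text_to_video.py \
    --model_path ./text2video-diffusers \
    --checkpoint_path ./text2video_pytorch_model.pth \
    --clip_checkpoint_path ./open_clip_converted.pth \
    --clip_base_path ./open_clip_pytorch_model.bin \
    --half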
@@ -394,31 +397,43 @@ def convert_text_enc_state_dict(text_enc_dict):
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")

# if osp.exists(vae_path):
# vae_state_dict = load_file(vae_path, device="cpu")
# else:
# vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
# vae_state_dict = torch.load(vae_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
#unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

print ('Saving UNET')
unet_state_dict = {**unet_state_dict}

if args.half:
unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}

if args.use_safetensors:
save_file(unet_state_dict, args.checkpoint_path)
else:
#state_dict = {"state_dict": state_dict}
torch.save(unet_state_dict, args.checkpoint_path)

del unet_state_dict
gc.collect()
print ('UNET Saved')

print ('Converting CLIP')

# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")

# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
#unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Debug aid: dump the text-encoder key names to a file for inspection
with open('l.txt', 'w') as deb:
text = '\n'.join(text_enc_dict.keys())
deb.write(text)

# Convert the VAE model
# vae_state_dict = convert_vae_state_dict(vae_state_dict)
# vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

print(is_v20_model)
if is_v20_model:

# MODELSCOPE always uses the 2.X encoder, btw --kabachuha

# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
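The v2.0 branch itself sits in the collapsed lines between this hunk and the next. A minimal sketch of the idea in the comment above, assuming the same approach as the upstream Stable Diffusion converter (the helper name below comes from that script and is not confirmed by this diff):

# Sketch, not part of the diff: prefix every key so the remapping table can
# match the final layer-norm together with the rest of the encoder weights.
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)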
@@ -429,33 +444,35 @@ def convert_text_enc_state_dict(text_enc_dict):
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
#text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

# DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLIT FORM --kabachuha
# Save CLIP and the Diffusion model to their own files
clip_base_path = args.clip_base_path

#state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
print ('Saving UNET')
state_dict = {**unet_state_dict}
# HACK: grab a preexisting openclip model and change only the converted layers
#if osp.exists(clip_base_path):
# text_enc_dict_base = load_file(clip_base_path, device="cpu")
#else:
text_enc_dict_base = torch.load(clip_base_path, map_location="cpu")

if args.half:
state_dict = {k: v.half() for k, v in state_dict.items()}
print ('Overwriting the converted CLIP layers in the base checkpoint')
for k, v in text_enc_dict.items():
if k in text_enc_dict_base:
text_enc_dict_base[k] = v

if args.use_safetensors:
save_file(state_dict, args.checkpoint_path)
else:
#state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)
text_enc_dict = text_enc_dict_base

# TODO: CLIP conversion doesn't work atm
# print ('Saving CLIP')
# state_dict = {**text_enc_dict}
print ('Saving CLIP')
text_enc_dict = {**text_enc_dict}

# if args.half:
# state_dict = {k: v.half() for k, v in state_dict.items()}
#with open('l1.txt', 'w') as deb:
# text = '\n'.join(text_enc_dict.keys())
# deb.write(text)

# if args.use_safetensors:
# save_file(state_dict, args.checkpoint_path)
# else:
# #state_dict = {"state_dict": state_dict}
# torch.save(state_dict, args.clip_checkpoint_path)
if args.half:
text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}

if args.use_safetensors:
save_file(text_enc_dict, args.clip_checkpoint_path)
else:
#state_dict = {"state_dict": text_enc_dict}
torch.save(text_enc_dict, args.clip_checkpoint_path)

print('Operation successful')
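Once a conversion has run, the two outputs can be spot-checked with a few lines of Python. A minimal sketch, assuming the default .pth outputs (no --use_safetensors) and the placeholder paths from the example invocation above:

import torch

unet_sd = torch.load("./text2video_pytorch_model.pth", map_location="cpu")
clip_sd = torch.load("./open_clip_converted.pth", map_location="cpu")

print(f"UNet tensors: {len(unet_sd)}  CLIP tensors: {len(clip_sd)}")
# Print a few remapped key names and shapes as a quick visual check
for k in list(unet_sd)[:5]:
    print(k, tuple(unet_sd[k].shape))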