From 5ba1085fb7cd9076c5cabb852a1b871454ff765a Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Sat, 10 Jun 2023 01:40:01 +0300
Subject: [PATCH] fix encoder conversion by using base model

Use the existing CLIP in ModelScope format and change only the changed
layers.
---
 ..._diffusers_to_original_ms_text_to_video.py | 93 +++++++++++--------
 1 file changed, 55 insertions(+), 38 deletions(-)

diff --git a/utils/convert_diffusers_to_original_ms_text_to_video.py b/utils/convert_diffusers_to_original_ms_text_to_video.py
index 5252e4f..6b6f3a7 100644
--- a/utils/convert_diffusers_to_original_ms_text_to_video.py
+++ b/utils/convert_diffusers_to_original_ms_text_to_video.py
@@ -6,7 +6,7 @@
 import os.path as osp
 import re
 
-import torch
+import torch, gc
 from safetensors.torch import load_file, save_file
 
 # =================#
@@ -369,6 +369,7 @@ def convert_text_enc_state_dict(text_enc_dict):
     parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
     parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
+    parser.add_argument("--clip_base_path", default=None, type=str, help="Path to the source original ModelScope-format (!) CLIP model.")
     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
     parser.add_argument(
         "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
@@ -380,7 +381,9 @@ def convert_text_enc_state_dict(text_enc_dict):
 
     assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
 
-    assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
+    assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint output path!"
+
+    assert args.clip_base_path is not None, "Must provide an existing original ModelScope format (!) CLIP checkpoint path!"
 
     # Path for safetensors
     unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
@@ -394,31 +397,43 @@ def convert_text_enc_state_dict(text_enc_dict):
         unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
         unet_state_dict = torch.load(unet_path, map_location="cpu")
 
-    # if osp.exists(vae_path):
-    #     vae_state_dict = load_file(vae_path, device="cpu")
-    # else:
-    #     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
-    #     vae_state_dict = torch.load(vae_path, map_location="cpu")
+    # Convert the UNet model
+    unet_state_dict = convert_unet_state_dict(unet_state_dict)
+    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+    print ('Saving UNET')
+    unet_state_dict = {**unet_state_dict}
+
+    if args.half:
+        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}
+
+    if args.use_safetensors:
+        save_file(unet_state_dict, args.checkpoint_path)
+    else:
+        #state_dict = {"state_dict": state_dict}
+        torch.save(unet_state_dict, args.checkpoint_path)
+
+    del unet_state_dict
+    gc.collect()
+    print ('UNET Saved')
+
+    print ('Converting CLIP')
 
+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     if osp.exists(text_enc_path):
         text_enc_dict = load_file(text_enc_path, device="cpu")
     else:
         text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
         text_enc_dict = torch.load(text_enc_path, map_location="cpu")
 
-    # Convert the UNet model
-    unet_state_dict = convert_unet_state_dict(unet_state_dict)
-    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+    with open('l.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)
 
-    # Convert the VAE model
-    # vae_state_dict = convert_vae_state_dict(vae_state_dict)
-    # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
-
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
 
+    print(is_v20_model)
+
     if is_v20_model:
-
         # MODELSCOPE always uses the 2.X encoder, btw --kabachuha
 
         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
@@ -429,33 +444,35 @@ def convert_text_enc_state_dict(text_enc_dict):
         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
         #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
 
-    # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
-
-    # Save CLIP and the Diffusion model to their own files
+    clip_base_path = args.clip_base_path
 
-    #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-    print ('Saving UNET')
-    state_dict = {**unet_state_dict}
+    # HACK: grab a preexisting openclip model and change only the converted layers
+    #if osp.exists(clip_base_path):
+    #    text_enc_dict_base = load_file(clip_base_path, device="cpu")
+    #else:
+    text_enc_dict_base = torch.load(clip_base_path, map_location="cpu")
 
-    if args.half:
-        state_dict = {k: v.half() for k, v in state_dict.items()}
+    print ('Changing the changed CLIP layers')
+    for k, v in text_enc_dict.items():
+        if k in text_enc_dict_base:
+            text_enc_dict_base[k] = v
 
-    if args.use_safetensors:
-        save_file(state_dict, args.checkpoint_path)
-    else:
-        #state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, args.checkpoint_path)
+    text_enc_dict = text_enc_dict_base
 
-    # TODO: CLIP conversion doesn't work atm
-    # print ('Saving CLIP')
-    # state_dict = {**text_enc_dict}
+    print ('Saving CLIP')
+    text_enc_dict = {**text_enc_dict}
 
-    # if args.half:
-    #     state_dict = {k: v.half() for k, v in state_dict.items()}
+    #with open('l1.txt', 'w') as deb:
+    #    text = '\n'.join(text_enc_dict.keys())
+    #    deb.write(text)
 
-    # if args.use_safetensors:
-    #     save_file(state_dict, args.checkpoint_path)
-    # else:
-    #     #state_dict = {"state_dict": state_dict}
-    #     torch.save(state_dict, args.clip_checkpoint_path)
+    if args.half:
+        text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}
+
+    if args.use_safetensors:
+        save_file(text_enc_dict, args.clip_checkpoint_path)
+    else:
+        #state_dict = {"state_dict": text_enc_dict}
+        torch.save(text_enc_dict, args.clip_checkpoint_path)
 
     print('Operation successfull')
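
For reference, the core of the fix distilled into plain Python: a minimal sketch of the layer-merge trick the patch introduces, with placeholder paths that are assumptions, not files from the repo. The idea is to start from a known-good CLIP checkpoint already in ModelScope format and overwrite only the keys the diffusers conversion produces, so nothing the diffusers checkpoint lacks has to be reconstructed from scratch.

    import torch

    # Output of convert_text_enc_state_dict() on the diffusers text encoder
    converted = torch.load("converted_clip_layers.pth", map_location="cpu")  # hypothetical path
    # Original ModelScope-format CLIP checkpoint (the --clip_base_path input)
    base = torch.load("modelscope_clip_base.bin", map_location="cpu")  # hypothetical path

    # Overwrite only the layers the conversion covers; every key the diffusers
    # checkpoint does not carry keeps its value from the base model.
    for key, value in converted.items():
        if key in base:
            base[key] = value

    torch.save(base, "merged_clip.pth")

This is presumably why the earlier from-scratch CLIP save was left commented out as "doesn't work atm": a straight conversion misses keys the original ModelScope checkpoint expects.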
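
A small optional sanity check after a run, again with assumed paths: since the merge only replaces tensor values of existing keys, the saved CLIP checkpoint's key set should be identical to the base checkpoint's.

    import torch

    merged = torch.load("merged_clip.pth", map_location="cpu")  # hypothetical output path
    base = torch.load("modelscope_clip_base.bin", map_location="cpu")  # hypothetical base path

    # The merge never adds or removes keys, only replaces values.
    assert set(merged) == set(base), "CLIP key sets diverged during conversion"
    print(f"OK: {len(merged)} tensors, key sets identical")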