This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit b2eee22 (1 parent: 9b14bbe)

fix encoder conversion by using base model

Use the existing CLIP in ModelScope format; change only the changed layers.
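The message describes reusing a CLIP checkpoint that already exists in ModelScope format as a base and overwriting only the layers the conversion actually touches, rather than rebuilding the whole text encoder from the diffusers weights. A minimal sketch of that idea (not the code in this commit): patch_base_clip, base_clip_path, and out_path are hypothetical names, and converted stands for whatever the script's convert_text_enc_state_dict() returns.

import torch

def patch_base_clip(base_clip_path: str, converted: dict, out_path: str) -> None:
    # Start from a text encoder checkpoint that is already in ModelScope's layout.
    base_clip = torch.load(base_clip_path, map_location="cpu")

    # Overwrite only the layers produced by the conversion; every other
    # tensor keeps the weights from the base model.
    patched = dict(base_clip)
    for key, tensor in converted.items():
        patched[key] = tensor

    torch.save(patched, out_path)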

1 file changed: utils/convert_diffusers_to_original_ms_text_to_video.py (+39, -32 lines)
@@ -6,7 +6,7 @@
 import os.path as osp
 import re
 
-import torch
+import torch, gc
 from safetensors.torch import load_file, save_file
 
 # =================#
@@ -400,25 +400,43 @@ def convert_text_enc_state_dict(text_enc_dict):
     # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
     # vae_state_dict = torch.load(vae_path, map_location="cpu")
 
+    # Convert the UNet model
+    unet_state_dict = convert_unet_state_dict(unet_state_dict)
+    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+    print ('Saving UNET')
+    unet_state_dict = {**unet_state_dict}
+
+    if args.half:
+        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}
+
+    if args.use_safetensors:
+        save_file(unet_state_dict, args.checkpoint_path)
+    else:
+        #state_dict = {"state_dict": state_dict}
+        torch.save(unet_state_dict, args.checkpoint_path)
+
+    del unet_state_dict
+    gc.collect()
+    print ('UNET Saved')
+
+    print ('Converting CLIP')
+
+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     if osp.exists(text_enc_path):
         text_enc_dict = load_file(text_enc_path, device="cpu")
     else:
         text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
         text_enc_dict = torch.load(text_enc_path, map_location="cpu")
 
-    # Convert the UNet model
-    unet_state_dict = convert_unet_state_dict(unet_state_dict)
-    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
-
-    # Convert the VAE model
-    # vae_state_dict = convert_vae_state_dict(vae_state_dict)
-    # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+    with open('l.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)
 
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
 
+    print(is_v20_model)
     if is_v20_model:
-
         # MODELSCOPE always uses the 2.X encoder, btw --kabachuha
 
         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
@@ -429,33 +447,22 @@ def convert_text_enc_state_dict(text_enc_dict):
         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
         #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
 
-    # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
-    # Save CLIP and the Diffusion model to their own files
 
-    #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-    print ('Saving UNET')
-    state_dict = {**unet_state_dict}
+    # TODO: CLIP conversion doesn't work atm
+    print ('Saving CLIP')
+    text_enc_dict = {**text_enc_dict}
+
+    with open('l1.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)
 
     if args.half:
-        state_dict = {k: v.half() for k, v in state_dict.items()}
+        text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}
 
     if args.use_safetensors:
-        save_file(state_dict, args.checkpoint_path)
+        save_file(text_enc_dict, args.checkpoint_path)
     else:
-        #state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, args.checkpoint_path)
-
-    # TODO: CLIP conversion doesn't work atm
-    # print ('Saving CLIP')
-    # state_dict = {**text_enc_dict}
-
-    # if args.half:
-    #     state_dict = {k: v.half() for k, v in state_dict.items()}
-
-    # if args.use_safetensors:
-    #     save_file(state_dict, args.checkpoint_path)
-    # else:
-    #     #state_dict = {"state_dict": state_dict}
-    #     torch.save(state_dict, args.clip_checkpoint_path)
+        #state_dict = {"state_dict": text_enc_dict}
+        torch.save(text_enc_dict, args.clip_checkpoint_path)
 
     print('Operation successfull')
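After this change the script writes the two halves in sequence: the converted UNet is saved to args.checkpoint_path, the UNet dict is deleted and gc.collect() is called, and only then is the text encoder loaded, converted, and saved on its own (to args.clip_checkpoint_path in the non-safetensors branch), matching the split-file layout the removed comment says ModelScope expects. Condensed into a hypothetical helper, with the conversion steps and argument parsing assumed from the rest of the script, the new flow looks roughly like this:

import gc
import torch
from safetensors.torch import save_file

def save_split_checkpoints(unet_state_dict, text_enc_dict, args):
    # Pass 1: write the UNet checkpoint.
    if args.half:
        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}
    if args.use_safetensors:
        save_file(unet_state_dict, args.checkpoint_path)
    else:
        torch.save(unet_state_dict, args.checkpoint_path)

    # Drop the reference to the UNet weights and run the garbage collector,
    # mirroring the commit, so the large dict can be freed before CLIP is handled.
    del unet_state_dict
    gc.collect()

    # Pass 2: write the text encoder to its own file, since ModelScope keeps
    # the UNet and CLIP checkpoints separate.
    if args.half:
        text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}
    torch.save(text_enc_dict, args.clip_checkpoint_path)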
