@@ -6,7 +6,7 @@
 import os.path as osp
 import re

-import torch
+import torch, gc
 from safetensors.torch import load_file, save_file

 # =================#
@@ -369,6 +369,7 @@ def convert_text_enc_state_dict(text_enc_dict):
     parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
     parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
+    parser.add_argument("--clip_base_path", default=None, type=str, help="Path to the source CLIP model in the original ModelScope format.")
     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
     parser.add_argument(
         "--use_safetensors", action="store_true", help="Save weights using safetensors; the default is ckpt."
@@ -380,7 +381,9 @@ def convert_text_enc_state_dict(text_enc_dict):

     assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

-    assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
+    assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint output path!"
+
+    assert args.clip_base_path is not None, "Must provide an existing CLIP checkpoint in the original ModelScope format!"

     # Path for safetensors
     unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
@@ -394,31 +397,43 @@ def convert_text_enc_state_dict(text_enc_dict):
         unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
         unet_state_dict = torch.load(unet_path, map_location="cpu")

-    # if osp.exists(vae_path):
-    #     vae_state_dict = load_file(vae_path, device="cpu")
-    # else:
-    #     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
-    #     vae_state_dict = torch.load(vae_path, map_location="cpu")
+    # Convert the UNet model
+    unet_state_dict = convert_unet_state_dict(unet_state_dict)
+    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+    print('Saving UNET')
+    unet_state_dict = {**unet_state_dict}
+
+    if args.half:
+        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}
+
+    if args.use_safetensors:
+        save_file(unet_state_dict, args.checkpoint_path)
+    else:
+        #state_dict = {"state_dict": state_dict}
+        torch.save(unet_state_dict, args.checkpoint_path)
+
+    del unet_state_dict
+    gc.collect()
+    print('UNET Saved')
+
+    print('Converting CLIP')

+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     if osp.exists(text_enc_path):
         text_enc_dict = load_file(text_enc_path, device="cpu")
     else:
         text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
         text_enc_dict = torch.load(text_enc_path, map_location="cpu")

-    # Convert the UNet model
-    unet_state_dict = convert_unet_state_dict(unet_state_dict)
-    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+    with open('l.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)

-    # Convert the VAE model
-    # vae_state_dict = convert_vae_state_dict(vae_state_dict)
-    # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
-
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

+    print(is_v20_model)
     if is_v20_model:
-
         # MODELSCOPE always uses the 2.X encoder, btw --kabachuha

         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
@@ -429,33 +444,35 @@ def convert_text_enc_state_dict(text_enc_dict):
         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
         #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

-    # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
-    # Save CLIP and the Diffusion model to their own files
+    clip_base_path = args.clip_base_path

-    #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-    print('Saving UNET')
-    state_dict = {**unet_state_dict}
+    # HACK: grab a preexisting openclip model and change only the converted layers
+    #if osp.exists(clip_base_path):
+    #    text_enc_dict_base = load_file(clip_base_path, device="cpu")
+    #else:
+    text_enc_dict_base = torch.load(clip_base_path, map_location="cpu")

-    if args.half:
-        state_dict = {k: v.half() for k, v in state_dict.items()}
+    print('Replacing the converted CLIP layers in the base checkpoint')
+    for k, v in text_enc_dict.items():
+        if k in text_enc_dict_base:
+            text_enc_dict_base[k] = v

-    if args.use_safetensors:
-        save_file(state_dict, args.checkpoint_path)
-    else:
-        #state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, args.checkpoint_path)
+    text_enc_dict = text_enc_dict_base

-    # TODO: CLIP conversion doesn't work atm
-    # print('Saving CLIP')
-    # state_dict = {**text_enc_dict}
+    print('Saving CLIP')
+    text_enc_dict = {**text_enc_dict}

-    # if args.half:
-    #     state_dict = {k: v.half() for k, v in state_dict.items()}
+    #with open('l1.txt', 'w') as deb:
+    #    text = '\n'.join(text_enc_dict.keys())
+    #    deb.write(text)

-    # if args.use_safetensors:
-    #     save_file(state_dict, args.checkpoint_path)
-    # else:
-    #     #state_dict = {"state_dict": state_dict}
-    #     torch.save(state_dict, args.clip_checkpoint_path)
+    if args.half:
+        text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}
+
+    if args.use_safetensors:
+        save_file(text_enc_dict, args.clip_checkpoint_path)
+    else:
+        #state_dict = {"state_dict": text_enc_dict}
+        torch.save(text_enc_dict, args.clip_checkpoint_path)

     print('Operation successful')
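
For reference, a minimal sketch of how the two files written above could be loaded back for a quick sanity check. It assumes the script was run with --use_safetensors, and the two file names are hypothetical stand-ins for whatever was passed as --checkpoint_path and --clip_checkpoint_path:

# Hypothetical sanity check: reload the converted UNet and CLIP checkpoints saved above.
from safetensors.torch import load_file

unet_sd = load_file("text2video_pytorch_model.safetensors", device="cpu")  # --checkpoint_path output (name is illustrative)
clip_sd = load_file("open_clip_pytorch_model.safetensors", device="cpu")   # --clip_checkpoint_path output (name is illustrative)

print(len(unet_sd), "UNet tensors;", len(clip_sd), "CLIP tensors")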