@@ -6,7 +6,7 @@
 import os.path as osp
 import re
 
-import torch
+import torch, gc
 from safetensors.torch import load_file, save_file
 
 # =================#
@@ -400,25 +400,43 @@ def convert_text_enc_state_dict(text_enc_dict):
     # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
     # vae_state_dict = torch.load(vae_path, map_location="cpu")
 
+    # Convert the UNet model
+    unet_state_dict = convert_unet_state_dict(unet_state_dict)
+    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+    print('Saving UNET')
+    unet_state_dict = {**unet_state_dict}
+
+    if args.half:
+        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}
+
+    if args.use_safetensors:
+        save_file(unet_state_dict, args.checkpoint_path)
+    else:
+        #state_dict = {"state_dict": state_dict}
+        torch.save(unet_state_dict, args.checkpoint_path)
+
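+    # Release the converted UNet weights before moving on to CLIP so that
+    # both state dicts never sit in RAM at the same time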
+    del unet_state_dict
+    gc.collect()
+    print('UNET Saved')
+
+    print('Converting CLIP')
+
+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     if osp.exists(text_enc_path):
         text_enc_dict = load_file(text_enc_path, device="cpu")
     else:
         text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
         text_enc_dict = torch.load(text_enc_path, map_location="cpu")
 
-    # Convert the UNet model
-    unet_state_dict = convert_unet_state_dict(unet_state_dict)
-    #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
-
-    # Convert the VAE model
-    # vae_state_dict = convert_vae_state_dict(vae_state_dict)
-    # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
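+    # Debug aid: dump the raw text-encoder key names to l.txt for inspection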
+    with open('l.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)
 
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
     is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
 
+    print(is_v20_model)
     if is_v20_model:
-
         # MODELSCOPE always uses the 2.X encoder, btw --kabachuha
 
         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
@@ -429,33 +447,22 @@ def convert_text_enc_state_dict(text_enc_dict):
         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
         #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
 
-    # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
-    # Save CLIP and the Diffusion model to their own files
 
-    #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-    print('Saving UNET')
-    state_dict = {**unet_state_dict}
+    # TODO: CLIP conversion doesn't work atm
+    print('Saving CLIP')
+    text_enc_dict = {**text_enc_dict}
+
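+    # Debug aid: dump the converted key names to l1.txt for comparison with l.txt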
+    with open('l1.txt', 'w') as deb:
+        text = '\n'.join(text_enc_dict.keys())
+        deb.write(text)
 
     if args.half:
-        state_dict = {k: v.half() for k, v in state_dict.items()}
+        text_enc_dict = {k: v.half() for k, v in text_enc_dict.items()}
 
     if args.use_safetensors:
-        save_file(state_dict, args.checkpoint_path)
+        save_file(text_enc_dict, args.checkpoint_path)
     else:
-        #state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, args.checkpoint_path)
-
-    # TODO: CLIP conversion doesn't work atm
-    # print('Saving CLIP')
-    # state_dict = {**text_enc_dict}
-
-    # if args.half:
-    # state_dict = {k: v.half() for k, v in state_dict.items()}
-
-    # if args.use_safetensors:
-    # save_file(state_dict, args.checkpoint_path)
-    # else:
-    # #state_dict = {"state_dict": state_dict}
-    # torch.save(state_dict, args.clip_checkpoint_path)
+        #state_dict = {"state_dict": text_enc_dict}
+        torch.save(text_enc_dict, args.clip_checkpoint_path)
 
     print('Operation successful')
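
For reference, a minimal sanity check for the two files the patched script now writes (a sketch, not part of the commit: the paths are placeholders for whatever was passed as args.checkpoint_path and args.clip_checkpoint_path, and it assumes the safetensors option was left off so both files are plain torch checkpoints):

    import torch

    # Placeholder paths, substitute the values given to the converter
    unet_sd = torch.load("unet_checkpoint.pth", map_location="cpu")
    clip_sd = torch.load("clip_checkpoint.pth", map_location="cpu")

    print(len(unet_sd), "UNet tensors,", len(clip_sd), "text-encoder tensors")

    # If the half option was passed, every tensor should report torch.float16
    print({v.dtype for v in unet_sd.values()} | {v.dtype for v in clip_sd.values()})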