File tree Expand file tree Collapse file tree 1 file changed +17
-3
lines changed Expand file tree Collapse file tree 1 file changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -847,15 +847,29 @@ def _load_embedding(path, model):
847847 key = "seq2seq.encoder.embed_tokens.weight"
848848 model .seq2seq .encoder .embed_tokens .weight .data = state [key ]
849849
850-
851850# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
851+
852+
def restore_parts(path, model):
    """Restore the parts of ``model`` that are present in a checkpoint.

    The checkpoint is read with ``_load(path)`` and its ``"state_dict"``
    entry is used. Only entries whose key also exists in
    ``model.state_dict()`` are considered; all other checkpoint entries are
    ignored.

    All matching entries are first loaded in a single ``load_state_dict``
    call. If that raises ``RuntimeError`` (presumably because at least one
    weight has an incompatible size), the entries are retried one at a time
    so that a rejected entry is skipped with a warning instead of aborting
    the whole restore.

    Args:
        path: Checkpoint location, passed unchanged to ``_load``.
        model: Object exposing ``state_dict()`` and ``load_state_dict()``;
            it is updated in place.

    Returns:
        None.
    """
    print("Restore part of the model from: {}".format(path))
    state = _load(path)["state_dict"]
    model_dict = model.state_dict()
    valid_state_dict = {k: v for k, v in state.items() if k in model_dict}

    try:
        model_dict.update(valid_state_dict)
        model.load_state_dict(model_dict)
    except RuntimeError as e:
        # there should be invalid size of weight(s), so load them per parameter
        print(str(e))
        model_dict = model.state_dict()
        for k, v in valid_state_dict.items():
            # Remember the model's own tensor so a rejected entry can be undone.
            original_value = model_dict[k]
            model_dict[k] = v
            try:
                model.load_state_dict(model_dict)
            except RuntimeError as param_error:
                print(str(param_error))
                warn("{}: may contain invalid size of weight. skipping...".format(k))
                # Put the model's own tensor back; otherwise the rejected
                # entry stays in model_dict and every later iteration fails
                # on it, mis-reporting valid parameters as invalid.
                model_dict[k] = original_value
859873
860874
861875if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments