File tree Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -823,8 +823,21 @@ def restore_parts(path, model):
823823 state = torch .load (path )["state_dict" ]
824824 model_dict = model .state_dict ()
825825 valid_state_dict = {k : v for k , v in state .items () if k in model_dict }
826- model_dict .update (valid_state_dict )
827- model .load_state_dict (model_dict )
826+
827+ try :
828+ model_dict .update (valid_state_dict )
829+ model .load_state_dict (model_dict )
830+ except RuntimeError as e :
831+ # there should be invalid size of weight(s), so load them per parameter
832+ print (str (e ))
833+ model_dict = model .state_dict ()
834+ for k , v in valid_state_dict .items ():
835+ model_dict [k ] = v
836+ try :
837+ model .load_state_dict (model_dict )
838+ except RuntimeError as e :
839+ print (str (e ))
840+ warn ("{}: may contain invalid size of weight. skipping..." .format (k ))
828841
829842
830843if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments