File tree Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -832,8 +832,21 @@ def restore_parts(path, model):
832832 state = torch .load (path )["state_dict" ]
833833 model_dict = model .state_dict ()
834834 valid_state_dict = {k : v for k , v in state .items () if k in model_dict }
835- model_dict .update (valid_state_dict )
836- model .load_state_dict (model_dict )
835+
836+ try :
837+ model_dict .update (valid_state_dict )
838+ model .load_state_dict (model_dict )
839+ except RuntimeError as e :
840+ # there should be invalid size of weight(s), so load them per parameter
841+ print (str (e ))
842+ model_dict = model .state_dict ()
843+ for k , v in valid_state_dict .items ():
844+ model_dict [k ] = v
845+ try :
846+ model .load_state_dict (model_dict )
847+ except RuntimeError as e :
848+ print (str (e ))
849+ warn ("{}: may contain invalid size of weight. skipping..." .format (k ))
837850
838851
839852if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments