@@ -815,12 +815,21 @@ def build_model():
     return model
 
 
+def _load(checkpoint_path):
+    if use_cuda:
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
 def load_checkpoint(path, model, optimizer, reset_optimizer):
     global global_step
     global global_epoch
 
     print("Load checkpoint from: {}".format(path))
-    checkpoint = torch.load(path)
+    checkpoint = _load(path)
     model.load_state_dict(checkpoint["state_dict"])
     if not reset_optimizer:
         optimizer_state = checkpoint["optimizer"]
@@ -834,15 +843,15 @@ def load_checkpoint(path, model, optimizer, reset_optimizer):
 
 
 def _load_embedding(path, model):
-    state = torch.load(path)["state_dict"]
+    state = _load(path)["state_dict"]
     key = "seq2seq.encoder.embed_tokens.weight"
     model.seq2seq.encoder.embed_tokens.weight.data = state[key]
 
 
 # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
 def restore_parts(path, model):
     print("Restore part of the model from: {}".format(path))
-    state = torch.load(path)["state_dict"]
+    state = _load(path)["state_dict"]
     model_dict = model.state_dict()
     valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
     model_dict.update(valid_state_dict)
@@ -951,7 +960,8 @@ def restore_parts(path, model):
     # Setup summary writer for tensorboard
     if log_event_path is None:
         if platform.system() == "Windows":
-            log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_").replace(":", "_")
+            log_event_path = "log/run-test" + \
+                str(datetime.now()).replace(" ", "_").replace(":", "_")
         else:
             log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_")
     print("Los event path: {}".format(log_event_path))
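
The new _load helper exists because torch.load by default restores each tensor to the device it was saved on, so a checkpoint written on a GPU machine raises an error when loaded on a CPU-only box; passing map_location=lambda storage, loc: storage keeps every storage on the CPU instead. Below is a minimal standalone sketch of the same pattern, not the commit's code: the nn.Linear toy model and the "model.pth" path are illustrative stand-ins, and use_cuda is derived here from torch.cuda.is_available() rather than the script's own flag.

    import torch
    import torch.nn as nn

    use_cuda = torch.cuda.is_available()


    def _load(checkpoint_path):
        # A GPU-saved checkpoint cannot be deserialized on a CPU-only machine
        # unless its storages are remapped; the lambda leaves them on the CPU.
        if use_cuda:
            return torch.load(checkpoint_path)
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)


    if __name__ == "__main__":
        net = nn.Linear(4, 2)
        torch.save({"state_dict": net.state_dict()}, "model.pth")
        checkpoint = _load("model.pth")        # works with or without CUDA
        net.load_state_dict(checkpoint["state_dict"])

Routing every checkpoint read in the script (load_checkpoint, _load_embedding, restore_parts) through this single helper keeps the CPU/GPU decision in one place instead of repeating the map_location argument at each call site.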