@@ -78,6 +78,15 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
7878 return waveform , alignment , spectrogram , mel
7979
8080
81+ def _load (checkpoint_path ):
82+ if use_cuda :
83+ checkpoint = torch .load (checkpoint_path )
84+ else :
85+ checkpoint = torch .load (checkpoint_path ,
86+ map_location = lambda storage , loc : storage )
87+ return checkpoint
88+
89+
8190if __name__ == "__main__" :
8291 args = docopt (__doc__ )
8392 print ("Command line args:\n " , args )
@@ -113,13 +122,13 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
113122
114123 # Load checkpoints separately
115124 if checkpoint_postnet_path is not None and checkpoint_seq2seq_path is not None :
116- checkpoint = torch . load (checkpoint_seq2seq_path )
125+ checkpoint = _load (checkpoint_seq2seq_path )
117126 model .seq2seq .load_state_dict (checkpoint ["state_dict" ])
118- checkpoint = torch . load (checkpoint_postnet_path )
127+ checkpoint = _load (checkpoint_postnet_path )
119128 model .postnet .load_state_dict (checkpoint ["state_dict" ])
120129 checkpoint_name = splitext (basename (checkpoint_seq2seq_path ))[0 ]
121130 else :
122- checkpoint = torch . load (checkpoint_path )
131+ checkpoint = _load (checkpoint_path )
123132 model .load_state_dict (checkpoint ["state_dict" ])
124133 checkpoint_name = splitext (basename (checkpoint_path ))[0 ]
125134
0 commit comments