Skip to content

Commit 30b04e9

Browse files
committed
Fix for CPU
1 parent 708a241 commit 30b04e9

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

synthesis.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
7878
return waveform, alignment, spectrogram, mel
7979

8080

81+
def _load(checkpoint_path):
82+
if use_cuda:
83+
checkpoint = torch.load(checkpoint_path)
84+
else:
85+
checkpoint = torch.load(checkpoint_path,
86+
map_location=lambda storage, loc: storage)
87+
return checkpoint
88+
89+
8190
if __name__ == "__main__":
8291
args = docopt(__doc__)
8392
print("Command line args:\n", args)
@@ -113,13 +122,13 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
113122

114123
# Load checkpoints separately
115124
if checkpoint_postnet_path is not None and checkpoint_seq2seq_path is not None:
116-
checkpoint = torch.load(checkpoint_seq2seq_path)
125+
checkpoint = _load(checkpoint_seq2seq_path)
117126
model.seq2seq.load_state_dict(checkpoint["state_dict"])
118-
checkpoint = torch.load(checkpoint_postnet_path)
127+
checkpoint = _load(checkpoint_postnet_path)
119128
model.postnet.load_state_dict(checkpoint["state_dict"])
120129
checkpoint_name = splitext(basename(checkpoint_seq2seq_path))[0]
121130
else:
122-
checkpoint = torch.load(checkpoint_path)
131+
checkpoint = _load(checkpoint_path)
123132
model.load_state_dict(checkpoint["state_dict"])
124133
checkpoint_name = splitext(basename(checkpoint_path))[0]
125134

0 commit comments

Comments
 (0)