Skip to content

Commit 1be4f8e

Browse files
authored
Merge pull request #40 from r9y9/fix-non-deterministic
Fix non deterministic incremental inference
2 parents cda1ca4 + 709639f commit 1be4f8e

File tree

5 files changed

+82
-7
lines changed

5 files changed

+82
-7
lines changed

deepvoice3_pytorch/deepvoice3.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,8 +487,17 @@ def incremental_forward(self, encoder_out, text_positions, speaker_embed=None,
487487
return outputs, alignments, dones, decoder_states
488488

489489
def start_fresh_sequence(self):
490-
for conv in self.convolutions:
491-
conv.clear_buffer()
490+
_clear_modules(self.preattention)
491+
_clear_modules(self.convolutions)
492+
self.last_conv.clear_buffer()
493+
494+
495+
def _clear_modules(modules):
496+
for m in modules:
497+
try:
498+
m.clear_buffer()
499+
except AttributeError as e:
500+
pass
492501

493502

494503
class Converter(nn.Module):

deepvoice3_pytorch/nyanko.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def incremental_forward(self, encoder_out, text_positions,
342342
def start_fresh_sequence(self):
343343
_clear_modules(self.audio_encoder_modules)
344344
_clear_modules(self.audio_decoder_modules)
345+
self.last_conv.clear_buffer()
345346

346347

347348
def _clear_modules(modules):

tests/data/ljspeech-mel-00001.npy

261 KB
Binary file not shown.

tests/test_deepvoice3.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
use_cuda = torch.cuda.is_available() and False
21+
torch.backends.cudnn.deterministic = True
2122
num_mels = 80
2223
num_freq = 513
2324
outputs_per_step = 4
@@ -145,13 +146,45 @@ def test_multi_speaker_deepvoice3():
145146
print("Done:", done.size())
146147

147148

148-
@attr("local_only")
149+
@attr("issue38")
150+
def test_incremental_path_multiple_times():
151+
texts = ["they discarded this for a more completely Roman and far less beautiful letter."]
152+
seqs = np.array([text_to_sequence(t) for t in texts])
153+
text_positions = np.arange(1, len(seqs[0]) + 1).reshape(1, len(seqs[0]))
154+
155+
r = 4
156+
mel_dim = 80
157+
sequence = Variable(torch.LongTensor(seqs))
158+
text_positions = Variable(torch.LongTensor(text_positions))
159+
160+
for model, speaker_ids in [
161+
(_get_model(force_monotonic_attention=False), None),
162+
(_get_model(force_monotonic_attention=False, n_speakers=32, speaker_embed_dim=16), Variable(torch.LongTensor([1])))]:
163+
model.eval()
164+
165+
# first call
166+
mel_outputs, linear_outputs, alignments, done = model(
167+
sequence, text_positions=text_positions, speaker_ids=speaker_ids)
168+
169+
# second call
170+
mel_outputs2, linear_outputs2, alignments2, done2 = model(
171+
sequence, text_positions=text_positions, speaker_ids=speaker_ids)
172+
173+
# Should get same result
174+
c = (mel_outputs - mel_outputs2).abs()
175+
print(c.mean(), c.max())
176+
177+
assert np.allclose(mel_outputs.cpu().data.numpy(),
178+
mel_outputs2.cpu().data.numpy(), atol=1e-5)
179+
180+
149181
def test_incremental_correctness():
150182
texts = ["they discarded this for a more completely Roman and far less beautiful letter."]
151183
seqs = np.array([text_to_sequence(t) for t in texts])
152184
text_positions = np.arange(1, len(seqs[0]) + 1).reshape(1, len(seqs[0]))
153185

154-
mel = np.load("/home/ryuichi/Dropbox/sp/deepvoice3_pytorch/data/ljspeech/ljspeech-mel-00035.npy")
186+
mel_path = join(dirname(__file__), "data", "ljspeech-mel-00001.npy")
187+
mel = np.load(mel_path)
155188
max_target_len = mel.shape[0]
156189
r = 4
157190
mel_dim = 80

tests/test_nyanko.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from deepvoice3_pytorch.builder import nyanko
1717
from deepvoice3_pytorch import MultiSpeakerTTSModel, AttentionSeq2Seq
1818

19-
use_cuda = torch.cuda.is_available()
19+
use_cuda = torch.cuda.is_available() and False
2020
num_mels = 80
2121
num_freq = 513
2222
outputs_per_step = 4
@@ -57,13 +57,45 @@ def test_nyanko_basics():
5757
mel_outputs, linear_outputs, alignments, done = model(x, y)
5858

5959

60-
@attr("local_only")
60+
@attr("issue38")
61+
def test_incremental_path_multiple_times():
62+
texts = ["they discarded this for a more completely Roman and far less beautiful letter."]
63+
seqs = np.array([text_to_sequence(t) for t in texts])
64+
text_positions = np.arange(1, len(seqs[0]) + 1).reshape(1, len(seqs[0]))
65+
66+
r = 1
67+
mel_dim = 80
68+
69+
sequence = Variable(torch.LongTensor(seqs))
70+
text_positions = Variable(torch.LongTensor(text_positions))
71+
72+
model = nyanko(n_vocab, mel_dim=mel_dim, linear_dim=513, downsample_step=4,
73+
r=r, force_monotonic_attention=False)
74+
model.eval()
75+
76+
# first call
77+
mel_outputs, linear_outputs, alignments, done = model(
78+
sequence, text_positions=text_positions, speaker_ids=None)
79+
80+
# second call
81+
mel_outputs2, linear_outputs2, alignments2, done2 = model(
82+
sequence, text_positions=text_positions, speaker_ids=None)
83+
84+
# Should get same result
85+
c = (mel_outputs - mel_outputs2).abs()
86+
print(c.mean(), c.max())
87+
88+
assert np.allclose(mel_outputs.cpu().data.numpy(),
89+
mel_outputs2.cpu().data.numpy(), atol=1e-5)
90+
91+
6192
def test_incremental_correctness():
6293
texts = ["they discarded this for a more completely Roman and far less beautiful letter."]
6394
seqs = np.array([text_to_sequence(t) for t in texts])
6495
text_positions = np.arange(1, len(seqs[0]) + 1).reshape(1, len(seqs[0]))
6596

66-
mel = np.load("/home/ryuichi/Dropbox/sp/deepvoice3_pytorch/data/ljspeech/ljspeech-mel-00035.npy")
97+
mel_path = join(dirname(__file__), "data", "ljspeech-mel-00001.npy")
98+
mel = np.load(mel_path)[::4]
6799
max_target_len = mel.shape[0]
68100
r = 1
69101
mel_dim = 80

0 commit comments

Comments
 (0)