Skip to content

Commit ed38dd2

Browse files
committed
Add sanity check for input and decoder target lengths
This should give more user-friendly error message than that reported in #33
1 parent 421e8b7 commit ed38dd2

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

train.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,14 @@ def train(model, data_loader, optimizer, writer,
596596
input_lengths = input_lengths.long().numpy()
597597
decoder_lengths = target_lengths.long().numpy() // r // downsample_step
598598

599+
max_seq_len = max(input_lengths.max(), decoder_lengths.max())
600+
if max_seq_len >= hparams.max_positions:
601+
raise RuntimeError(
602+
"""max_seq_len ({}) >= max_posision ({})
603+
Input text or decoder targget length exceeded the maximum length.
604+
Please set a larger value for ``max_position`` in hyper parameters.""".format(
605+
max_seq_len, hparams.max_positions))
606+
599607
# Feed data
600608
x, mel, y = Variable(x), Variable(mel), Variable(y)
601609
text_positions = Variable(text_positions)

0 commit comments

Comments
 (0)