Skip to content

Commit 08dbe7c

Browse files
committed
A few cleanup of #54
1 parent 9bc4943 commit 08dbe7c

File tree

2 files changed

+14
-18
lines changed

2 files changed

+14
-18
lines changed

hparams.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
builder="deepvoice3",
2828

2929
# Must be configured depends on the dataset and model you use
30-
n_speakers=1,
30+
n_speakers=1,
3131
speaker_embed_dim=16,
3232

3333
# Audio:
@@ -80,7 +80,7 @@
8080

8181
# Data loader
8282
pin_memory=True,
83-
num_workers=2, # Set it to 1 when in Windows (MemoryError, THAllocator.c 0x5)
83+
num_workers=2, # Set it to 1 when in Windows (MemoryError, THAllocator.c 0x5)
8484

8585
# Loss
8686
masked_loss_weight=0.5, # (1-w)*loss + w * masked_loss
@@ -120,16 +120,14 @@
120120
# 0 tends to prevent word repretetion, but sometime causes skip words
121121
window_backward=1,
122122
power=1.4, # Power to raise magnitudes to prior to phase retrieval
123-
123+
124124
# GC:
125-
# Forced garbage collection probability
125+
# Forced garbage collection probability
126126
# Use only when MemoryError continues in Windows (Disabled by default)
127127
#gc_probability = 0.001,
128128
)
129129

130130

131-
132-
133131
def hparams_debug_string():
134132
values = hparams.values()
135133
hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]

train.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
"""
2222
from docopt import docopt
2323

24-
import sys, gc, platform
24+
import sys
25+
import gc
26+
import platform
2527
from os.path import dirname, join
2628
from tqdm import tqdm, trange
2729
from datetime import datetime
@@ -134,14 +136,14 @@ def collect_features(self, *args):
134136
if _frontend is None:
135137
_frontend = getattr(frontend, hparams.frontend)
136138
seq = _frontend.text_to_sequence(text, p=hparams.replace_pronunciation_prob)
137-
139+
138140
if platform.system() == "Windows":
139141
if hasattr(hparams, 'gc_probability'):
140-
_frontend = None # memory leaking prevention in Windows
142+
_frontend = None # memory leaking prevention in Windows
141143
if np.random.rand() < hparams.gc_probability:
142-
gc.collect() # garbage collection enforced
144+
gc.collect() # garbage collection enforced
143145
print("GC done")
144-
146+
145147
if self.multi_speaker:
146148
return np.asarray(seq, dtype=np.int32), int(speaker_id)
147149
else:
@@ -723,7 +725,6 @@ def train(model, data_loader, optimizer, writer,
723725
if global_step > 0 and global_step % hparams.eval_interval == 0:
724726
eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker)
725727

726-
727728
# Update
728729
loss.backward()
729730
if clip_thresh > 0:
@@ -888,15 +889,12 @@ def restore_parts(path, model):
888889
hparams.parse_json(f.read())
889890
# Override hyper parameters
890891
hparams.parse(args["--hparams"])
891-
892-
# Preventing Windows specific error such as MemoryError
892+
893+
# Preventing Windows specific error such as MemoryError
893894
# Also reduces the occurrence of THAllocator.c 0x05 error in Widows build of PyTorch
894895
if platform.system() == "Windows":
895896
print("Windows Detected - num_workers set to 1")
896-
hparams.set_hparam('num_workers',1)
897-
898-
# Now, print the finalized hparams.
899-
print(hparams_debug_string())
897+
hparams.set_hparam('num_workers', 1)
900898

901899
assert hparams.name == "deepvoice3"
902900
print(hparams_debug_string())

0 commit comments

Comments
 (0)