|
21 | 21 | """ |
22 | 22 | from docopt import docopt |
23 | 23 |
|
24 | | -import sys, gc, platform |
| 24 | +import sys |
| 25 | +import gc |
| 26 | +import platform |
25 | 27 | from os.path import dirname, join |
26 | 28 | from tqdm import tqdm, trange |
27 | 29 | from datetime import datetime |
@@ -134,14 +136,14 @@ def collect_features(self, *args): |
134 | 136 | if _frontend is None: |
135 | 137 | _frontend = getattr(frontend, hparams.frontend) |
136 | 138 | seq = _frontend.text_to_sequence(text, p=hparams.replace_pronunciation_prob) |
137 | | - |
| 139 | + |
138 | 140 | if platform.system() == "Windows": |
139 | 141 | if hasattr(hparams, 'gc_probability'): |
140 | | - _frontend = None # memory leaking prevention in Windows |
| 142 | + _frontend = None # memory leaking prevention in Windows |
141 | 143 | if np.random.rand() < hparams.gc_probability: |
142 | | - gc.collect() # garbage collection enforced |
| 144 | + gc.collect() # garbage collection enforced |
143 | 145 | print("GC done") |
144 | | - |
| 146 | + |
145 | 147 | if self.multi_speaker: |
146 | 148 | return np.asarray(seq, dtype=np.int32), int(speaker_id) |
147 | 149 | else: |
@@ -723,7 +725,6 @@ def train(model, data_loader, optimizer, writer, |
723 | 725 | if global_step > 0 and global_step % hparams.eval_interval == 0: |
724 | 726 | eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker) |
725 | 727 |
|
726 | | - |
727 | 728 | # Update |
728 | 729 | loss.backward() |
729 | 730 | if clip_thresh > 0: |
@@ -888,15 +889,12 @@ def restore_parts(path, model): |
888 | 889 | hparams.parse_json(f.read()) |
889 | 890 | # Override hyper parameters |
890 | 891 | hparams.parse(args["--hparams"]) |
891 | | - |
892 | | - # Preventing Windows specific error such as MemoryError |
| 892 | + |
| 893 | + # Preventing Windows specific error such as MemoryError |
893 | 894 | # Also reduces the occurrence of THAllocator.c 0x05 error in Widows build of PyTorch |
894 | 895 | if platform.system() == "Windows": |
895 | 896 | print("Windows Detected - num_workers set to 1") |
896 | | - hparams.set_hparam('num_workers',1) |
897 | | - |
898 | | - # Now, print the finalized hparams. |
899 | | - print(hparams_debug_string()) |
| 897 | + hparams.set_hparam('num_workers', 1) |
900 | 898 |
|
901 | 899 | assert hparams.name == "deepvoice3" |
902 | 900 | print(hparams_debug_string()) |
|
0 commit comments