Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5242ac6

Browse files
author
Ryan Sepassi
committed
Rm num_shards from Problem. Problems specify sharding themselves.
PiperOrigin-RevId: 163281576
1 parent c01617e commit 5242ac6

File tree

7 files changed

+23
-24
lines changed

7 files changed

+23
-24
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
8686
t2t-datagen \
8787
--data_dir=$DATA_DIR \
8888
--tmp_dir=$TMP_DIR \
89-
--num_shards=100 \
9089
--problem=$PROBLEM
9190
9291
# Train

tensor2tensor/bin/t2t-datagen

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ flags.DEFINE_string("problem", "",
6363
"The name of the problem to generate data for.")
6464
flags.DEFINE_string("exclude_problems", "",
6565
"Comma-separated list of problems to exclude.")
66-
flags.DEFINE_integer("num_shards", 10, "How many shards to use.")
66+
flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
67+
"registered Problems.")
6768
flags.DEFINE_integer("max_cases", 0,
6869
"Maximum number of cases to generate (unbounded if 0).")
6970
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
@@ -252,7 +253,7 @@ def generate_data_for_problem(problem):
252253
if isinstance(dev_gen, int):
253254
# The dev set and test sets are generated as extra shards using the
254255
# training generator. The integer specifies the number of training
255-
# shards. FLAGS.num_shards is ignored.
256+
# shards. FLAGS.num_shards is ignored.
256257
num_training_shards = dev_gen
257258
tf.logging.info("Generating data for %s.", problem)
258259
all_output_files = generator_utils.combined_data_filenames(
@@ -263,10 +264,11 @@ def generate_data_for_problem(problem):
263264
else:
264265
# usual case - train data and dev data are generated using separate
265266
# generators.
267+
num_shards = FLAGS.num_shards or 10
266268
tf.logging.info("Generating training data for %s.", problem)
267269
train_output_files = generator_utils.train_data_filenames(
268270
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
269-
FLAGS.num_shards)
271+
num_shards)
270272
generator_utils.generate_files(training_gen(), train_output_files,
271273
FLAGS.max_cases)
272274
tf.logging.info("Generating development data for %s.", problem)
@@ -282,11 +284,12 @@ def generate_data_for_problem(problem):
282284

283285
def generate_data_for_registered_problem(problem_name):
284286
tf.logging.info("Generating training data for %s.", problem_name)
287+
if FLAGS.num_shards:
288+
raise ValueError("--num_shards should not be set for registered Problem.")
285289
problem = registry.problem(problem_name)
286290
task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
287291
problem.generate_data(os.path.expanduser(FLAGS.data_dir),
288292
os.path.expanduser(FLAGS.tmp_dir),
289-
num_shards=FLAGS.num_shards,
290293
task_id=task_id)
291294

292295

tensor2tensor/data_generators/algorithmic.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,7 @@ def dev_size(self):
6666
def num_shards(self):
6767
return 10
6868

69-
def generate_data(self, data_dir, _, num_shards=None, task_id=-1):
70-
if num_shards is None:
71-
num_shards = self.num_shards
72-
69+
def generate_data(self, data_dir, _, task_id=-1):
7370
def generator_eos(generator):
7471
"""Shift by NUM_RESERVED_IDS and append EOS token."""
7572
for case in generator:
@@ -87,7 +84,7 @@ def generator_eos(generator):
8784

8885
utils.generate_dataset_and_shuffle(
8986
train_generator_eos(),
90-
self.training_filepaths(data_dir, num_shards, shuffled=True),
87+
self.training_filepaths(data_dir, self.num_shards, shuffled=True),
9188
dev_generator_eos(),
9289
self.dev_filepaths(data_dir, 1, shuffled=True),
9390
shuffle=False)
@@ -254,7 +251,7 @@ def zipf_distribution(nbr_symbols, alpha):
254251

255252

256253
def zipf_random_sample(distr_map, sample_len):
257-
"""Helper function: Generate a random Zipf sample of given lenght.
254+
"""Helper function: Generate a random Zipf sample of given length.
258255
259256
Args:
260257
distr_map: list of float, Zipf's distribution over nbr_symbols.
@@ -287,7 +284,7 @@ def reverse_generator_nlplike(nbr_symbols,
287284
max_length: integer, maximum length of sequences to generate.
288285
nbr_cases: the number of cases to generate.
289286
scale_std_dev: float, Normal distribution's standard deviation scale factor
290-
used to draw the lenght of sequence. Default = 1% of the max_length.
287+
used to draw the length of sequence. Default = 1% of the max_length.
291288
alpha: float, Zipf's Law Distribution parameter. Default = 1.5.
292289
Usually for modelling natural text distribution is in
293290
the range [1.1-1.6].

tensor2tensor/data_generators/genetics.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,11 @@ def feature_encoders(self, data_dir):
8787
"targets": text_encoder.TextEncoder()
8888
}
8989

90-
def generate_data(self, data_dir, tmp_dir, num_shards=None, task_id=-1):
91-
if num_shards is None:
92-
num_shards = 100
90+
@property
91+
def num_shards(self):
92+
return 100
9393

94+
def generate_data(self, data_dir, tmp_dir, task_id=-1):
9495
try:
9596
# Download source data if download_url specified
9697
h5_filepath = generator_utils.maybe_download(tmp_dir, self.h5_file,
@@ -109,7 +110,7 @@ def generate_data(self, data_dir, tmp_dir, num_shards=None, task_id=-1):
109110
# Collect created shard processes to start and join
110111
processes = []
111112

112-
datasets = [(self.training_filepaths, num_shards, "train",
113+
datasets = [(self.training_filepaths, self.num_shards, "train",
113114
num_train_examples), (self.dev_filepaths, 1, "valid",
114115
num_dev_examples),
115116
(self.test_filepaths, 1, "test", num_test_examples)]
@@ -124,9 +125,10 @@ def generate_data(self, data_dir, tmp_dir, num_shards=None, task_id=-1):
124125
start_idx, end_idx))
125126
processes.append(p)
126127

127-
# Start and wait for processes in batches
128-
assert len(processes) == num_shards + 2 # 1 per training shard + dev + test
128+
# 1 per training shard + dev + test
129+
assert len(processes) == self.num_shards + 2
129130

131+
# Start and wait for processes in batches
130132
num_batches = int(
131133
math.ceil(float(len(processes)) / MAX_CONCURRENT_PROCESSES))
132134
for i in xrange(num_batches):

tensor2tensor/data_generators/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def example_reading_spec(self, label_key=None):
338338
class ImageFSNS(ImageProblem):
339339
"""Problem spec for French Street Name recognition."""
340340

341-
def generate_data(self, data_dir, tmp_dir, num_shards=None, task_id=-1):
341+
def generate_data(self, data_dir, tmp_dir, task_id=-1):
342342
list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
343343
"street/python/fsns_urls.txt")
344344
fsns_urls = generator_utils.maybe_download(

tensor2tensor/data_generators/problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class Problem(object):
135135
# BEGIN SUBCLASS INTERFACE
136136
# ============================================================================
137137

138-
def generate_data(self, data_dir, tmp_dir, num_shards=None, task_id=-1):
138+
def generate_data(self, data_dir, tmp_dir, task_id=-1):
139139
raise NotImplementedError()
140140

141141
def hparams(self, defaults, model_hparams):

tensor2tensor/data_generators/wmt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,10 @@ def vocab_name(self):
8383
def vocab_file(self):
8484
return "%s.%d" % (self.vocab_name, self.targeted_vocab_size)
8585

86-
def generate_data(self, data_dir, tmp_dir, num_shards=None, task_id=-1):
87-
if num_shards is None:
88-
num_shards = self.num_shards
86+
def generate_data(self, data_dir, tmp_dir, task_id=-1):
8987
generator_utils.generate_dataset_and_shuffle(
9088
self.train_generator(data_dir, tmp_dir, True),
91-
self.training_filepaths(data_dir, num_shards, shuffled=False),
89+
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
9290
self.dev_generator(data_dir, tmp_dir),
9391
self.dev_filepaths(data_dir, 1, shuffled=False))
9492

0 commit comments

Comments (0)