Skip to content

Commit 76fd210

Browse files
committed
Don't try and get length for inifinite datasets
Summary: Fix T57598 TF2.5 Only Test Plan: Added a test Reviewers: christiana, jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Reviewed By: christiana, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Subscribers: harrym Maniphest Tasks: T57598 Differential Revision: https://phabricator.sourcevertex.net/D62644
1 parent 5b2b07d commit 76fd210

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

tensorflow/python/ipu/keras/extensions/data_adapter.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,14 @@ def _infer_steps(self, steps, dataset):
174174
"Could not infer the size of the data. You must specify the number "
175175
"of steps to run.")
176176
if steps % self._replication_factor:
177-
logging.warn(
178-
"Dataset of length {} is being evenly distributed between {} "
179-
"replicas. The remaining {} batch{} will be dropped.".format(
180-
len(dataset), self._replication_factor,
181-
steps % self._replication_factor,
182-
"es" if steps % self._replication_factor > 1 else ""))
177+
size = cardinality.cardinality(dataset)
178+
if size >= 0:
179+
logging.warn(
180+
"Dataset of length {} is being evenly distributed between {} "
181+
"replicas. The remaining {} batch{} will be dropped.".format(
182+
len(dataset), self._replication_factor,
183+
steps % self._replication_factor,
184+
"es" if steps % self._replication_factor > 1 else ""))
183185

184186
return int(steps // self._replication_factor)
185187

tensorflow/python/ipu/tests/keras/extensions/data_adapter_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,15 @@ def test_steps_per_execution_not_divisor(self):
287287
x, epochs=1, batch_size=2, steps_per_execution=variables.Variable(3))
288288
del data_handler
289289

290+
def test_dataset_drop_batch_with_replication(self):
291+
x = np.ones((6, 1))
292+
x = dataset_ops.Dataset.from_tensor_slices(x).repeat()
293+
data_handler = data_adapter.IPUDataHandler(x,
294+
steps_per_epoch=7,
295+
replication_factor=4)
296+
del data_handler
297+
# No exception raised.
298+
290299

291300
if __name__ == '__main__':
292301
ops.enable_eager_execution()

0 commit comments

Comments
 (0)