Skip to content

Commit 5b2b07d

Browse files
hmellor authored and georgepaw committed
Warn if steps_per_execution truncated, error if not divisor of inferred_steps
Summary: Reinstates the more informative error ([here in SDK2.4](https://phabricator.sourcevertex.net/diffusion/TENSORFLOW/browse/poplar%252Fr2.4%252Frelease/tensorflow/python/ipu/keras/extensions/data_adapter.py$210)) that tells the user how they have incorrectly set `steps_per_execution` (that is, when it isn't a divisor of the `inferred_steps` of the dataset). Currently, the error thrown is an Infeed overflow error, which isn't very clear. Additionally, log a warning if truncation has occurred due to `steps_per_execution` being larger than the `inferred_steps` of the dataset. TF2.5 only.
Test Plan: Test added to ensure that the error is thrown under the necessary circumstances.
Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, markf, christiana
Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, christiana
Subscribers: christiana
Maniphest Tasks: T55559
Differential Revision: https://phabricator.sourcevertex.net/D60817
1 parent 66e8b99 commit 5b2b07d

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

tensorflow/compiler/plugin/poplar/tests/distributed_tf2_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,19 @@ def test_tf2_distributed_popdist_strategy(self):
164164
strategy = popdist_strategy.PopDistStrategy()
165165

166166
with strategy.scope():
167-
dataset = test_dataset(popdist.getNumInstances() * batch_size *
167+
dataset = test_dataset(popdist.getNumTotalReplicas() * batch_size *
168168
steps_to_run,
169169
batch_size=batch_size)
170+
dataset = dataset.shard(num_shards=popdist.getNumInstances(),
171+
index=popdist.getInstanceIndex())
172+
steps_per_execution = len(dataset) // popdist.getNumLocalReplicas()
170173
model = simple_model()
171174
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
172175
loss_fn = tf.keras.losses.MeanSquaredError()
173176

174177
model.compile(optimizer=optimizer,
175178
loss=loss_fn,
176-
steps_per_execution=popdist.getNumTotalReplicas())
179+
steps_per_execution=steps_per_execution)
177180

178181
# Build the model separately so we can assert that the biases are
179182
# broadcasted properly before training.

tensorflow/python/ipu/keras/extensions/data_adapter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,15 @@ def __init__( # pylint: disable=super-init-not-called
139139
def _validate_data_handler(self):
140140
super()._validate_data_handler()
141141

142+
if self.steps_per_execution_value > self.inferred_steps:
143+
logging.warn(
144+
"`steps_per_execution` has been set to {} but the dataset "
145+
"provided{} only contains {} batches. Using {} as "
146+
"`steps_per_execution`.".format(
147+
self.steps_per_execution_value,
148+
" to this replica" if self._replication_factor > 1 else "",
149+
self.inferred_steps, self.inferred_steps))
150+
142151
with self._truncate_execution_to_epoch():
143152
if self.inferred_steps == 0:
144153
steps_per_replica = math.ceil(
@@ -151,6 +160,11 @@ def _validate_data_handler(self):
151160
len(self._dataset), self._replication_factor,
152161
steps_per_replica,
153162
steps_per_replica * self._replication_factor))
163+
elif self.inferred_steps % self._steps_per_execution != 0:
164+
raise ValueError(
165+
"`steps_per_execution` must be a divisor of the number of batches "
166+
"in the dataset provided{}.".format(
167+
" to this replica" if self._replication_factor > 1 else ""))
154168

155169
def _infer_steps(self, steps, dataset):
156170
"""Infers steps_per_epoch needed to loop through a dataset."""

tensorflow/python/ipu/tests/keras/extensions/data_adapter_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import parameterized
1818
import numpy as np
1919

20+
from tensorflow.python.ops import variables
2021
from tensorflow.python.data.experimental.ops import cardinality
2122
from tensorflow.python.data.ops import dataset_ops
2223
from tensorflow.python.framework import ops
@@ -275,6 +276,17 @@ def test_dataset_not_big_enough_with_replication(self):
275276
replication_factor=4)
276277
del data_handler
277278

279+
def test_steps_per_execution_not_divisor(self):
280+
x = np.ones((8, 1))
281+
x = dataset_ops.Dataset.from_tensor_slices(x).batch(2, drop_remainder=True)
282+
with self.assertRaisesRegex(
283+
ValueError,
284+
r"`steps_per_execution` must be a divisor of the number of batches in "
285+
r"the dataset provided."):
286+
data_handler = data_adapter.IPUDataHandler(
287+
x, epochs=1, batch_size=2, steps_per_execution=variables.Variable(3))
288+
del data_handler
289+
278290

279291
if __name__ == '__main__':
280292
ops.enable_eager_execution()

0 commit comments

Comments
 (0)