Skip to content

Commit 84b15f5

Browse files
Alfie-Edwardsgeorgepaw
authored andcommitted
Allowing explicit deferral of setting replication_factor in IPUDataHandler
Summary: Support passing replication_factor=None into the IPUDataHandler constructor to explicitly defer setting replication_factor. Previously the best you could do was leave it as the default 1 but this sometime lead to errors validating the values. TF2.5 Only Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, harrym, christiana Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, christiana Subscribers: harrym Maniphest Tasks: T58709 Differential Revision: https://phabricator.sourcevertex.net/D63322
1 parent 0ea3168 commit 84b15f5

File tree

4 files changed

+75
-18
lines changed

4 files changed

+75
-18
lines changed

tensorflow/python/ipu/keras/extensions/data_adapter.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def _call_counter(token): #pylint: disable=missing-type-doc,missing-return-type
4848

4949

5050
class IPUDataHandler(data_adapter.DataHandler):
51-
"""Handles iterating over epoch-level Iterator objects on IPU."""
51+
"""Handles iterating over epoch-level Iterator objects on IPU.
52+
To defer setting the replication factor pass replication_factor=None,
53+
then set it later using set_replication_factor."""
5254
def __init__( # pylint: disable=super-init-not-called
5355
self,
5456
x,
@@ -121,22 +123,27 @@ def __init__( # pylint: disable=super-init-not-called
121123
"compatible with the IPU's requirement to know the data shape ahead "
122124
"of time. Please {}".format(type(x).__name__, hint))
123125

124-
dataset = self._get_and_post_process_dataset(class_weight)
125-
126-
self._replication_factor = replication_factor
127-
self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)
126+
self._dataset = self._get_and_post_process_dataset(class_weight)
128127
self._steps_per_epoch = steps_per_epoch
128+
self._replication_factor = None
129+
self._inferred_steps = None
129130

130-
self._validate_dataset(dataset)
131-
132-
self._dataset = dataset
133131
self._current_step = 0
134132
self._step_increment = self._steps_per_execution_value - 1
135133
self._insufficient_data = False
136134

137-
self._validate_data_handler()
135+
if replication_factor is not None:
136+
self.set_replication_factor(replication_factor)
137+
138+
def _check_replication_factor_set(self, function_name):
139+
if self._replication_factor is None or self._infer_steps is None:
140+
raise RuntimeError(
141+
f"Cannot call {function_name} before the replication factor is set. "
142+
f"Either specify the replication factor in the constructor or set it "
143+
f"using set_replication_factor.")
138144

139145
def _validate_data_handler(self):
146+
self._check_replication_factor_set("_validate_data_handler")
140147
super()._validate_data_handler()
141148

142149
if self.steps_per_execution_value > self.inferred_steps:
@@ -212,6 +219,7 @@ def _get_and_post_process_dataset(self, class_weight):
212219
return dataset
213220

214221
def _validate_dataset(self, dataset):
222+
self._check_replication_factor_set("_validate_dataset")
215223
# Validate the size of the dataset.
216224
dataset_size = cardinality.cardinality(dataset)
217225
if dataset_size == cardinality.UNKNOWN:
@@ -263,6 +271,12 @@ def batch_size(self):
263271
return batch_size
264272

265273
def set_replication_factor(self, value):
274+
"""Set the replication factor and calculate inferred steps based on the
275+
replication factor.
276+
277+
Args:
278+
value (int): The value for the replication factor.
279+
"""
266280
self._replication_factor = value
267281
self._inferred_steps = self._infer_steps(self._steps_per_epoch,
268282
self._dataset)
@@ -271,6 +285,7 @@ def set_replication_factor(self, value):
271285

272286
def enumerate_epochs_with_reuse(self, manager, mode, infeed_kwargs):
273287
"""Yields `(epoch, InfeedQueue)`."""
288+
self._check_replication_factor_set("enumerate_epochs_with_reuse")
274289
with self._truncate_execution_to_epoch():
275290
data_iterator = manager.get_infeed(mode, self._dataset, infeed_kwargs)
276291
for epoch in range(self._initial_epoch, self._epochs):

tensorflow/python/ipu/keras/extensions/keras_extension_base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,8 @@ def _fit_delegate(self,
12141214
workers=workers,
12151215
use_multiprocessing=use_multiprocessing,
12161216
model=self,
1217-
steps_per_execution=self._steps_per_execution)
1217+
steps_per_execution=self._steps_per_execution,
1218+
replication_factor=None)
12181219

12191220
# Build the model with specific dtypes. This is important for models
12201221
# without explicit input dtypes (model subclasses and some sequential
@@ -1440,7 +1441,8 @@ def _evaluate_delegate(self,
14401441
workers=workers,
14411442
use_multiprocessing=use_multiprocessing,
14421443
model=self,
1443-
steps_per_execution=self._steps_per_execution)
1444+
steps_per_execution=self._steps_per_execution,
1445+
replication_factor=None)
14441446

14451447
# Build the model with specific dtypes. This is important for models
14461448
# without explicit input dtypes (model subclasses and some sequential
@@ -1597,7 +1599,8 @@ def _predict_delegate(self,
15971599
workers=workers,
15981600
use_multiprocessing=use_multiprocessing,
15991601
model=self,
1600-
steps_per_execution=self._steps_per_execution)
1602+
steps_per_execution=self._steps_per_execution,
1603+
replication_factor=None)
16011604

16021605
# Build the model with specific dtypes. This is important for models
16031606
# without explicit input dtypes (model subclasses and some sequential

tensorflow/python/ipu/tests/keras/extensions/data_adapter_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,19 @@ def test_steps_per_execution_not_divisor(self):
287287
x, epochs=1, batch_size=2, steps_per_execution=variables.Variable(3))
288288
del data_handler
289289

290+
def test_deferred_setting_of_replication_factor(self):
291+
x = np.ones((8, 1))
292+
x = dataset_ops.Dataset.from_tensor_slices(x).batch(2, drop_remainder=True)
293+
# With a batch size of 2, the steps_per_execution value of 3 is not valid
294+
# until we set a replication_factor which is a muliple of 3.
295+
data_handler = data_adapter.IPUDataHandler(
296+
x,
297+
epochs=1,
298+
batch_size=2,
299+
steps_per_execution=variables.Variable(3),
300+
replication_factor=None)
301+
data_handler.set_replication_factor(3)
302+
290303
def test_dataset_drop_batch_with_replication(self):
291304
x = np.ones((6, 1))
292305
x = dataset_ops.Dataset.from_tensor_slices(x).repeat()

tensorflow/python/ipu/tests/keras/keras_poprun_test.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from tensorflow.python.keras import Model, Input
3333
from tensorflow.python.keras.engine import data_adapter
3434
from tensorflow.python.keras.optimizer_v2 import gradient_descent
35-
from tensorflow.python.keras.callbacks import History
35+
from tensorflow.python.keras.callbacks import Callback, History
3636
from tensorflow.python.types.core import Tensor
3737

3838

@@ -210,39 +210,44 @@ def predict_on_full_dataset_without_fixed_size_with_fixed_steps(model):
210210

211211

212212
TESTCASES = [
213+
# The first arg is the name to use for the test case.
213214
(
215+
"fit_on_full_dataset_with_fixed_size_one_epoch",
214216
fit_on_full_dataset_with_fixed_size_one_epoch,
215217
True,
216218
),
217219
(
220+
"fit_on_full_dataset_with_fixed_size_two_epochs",
218221
fit_on_full_dataset_with_fixed_size_two_epochs,
219222
True,
220223
),
221224
(
225+
"evaluate_on_full_dataset_with_fixed_size",
222226
evaluate_on_full_dataset_with_fixed_size,
223227
False,
224228
),
225229
(
230+
"evaluate_on_full_dataset_with_fixed_size_with_fixed_steps",
226231
evaluate_on_full_dataset_with_fixed_size_with_fixed_steps,
227232
False,
228233
),
229234
(
235+
"evaluate_on_full_dataset_without_fixed_size_with_fixed_steps",
230236
evaluate_on_full_dataset_without_fixed_size_with_fixed_steps,
231237
False,
232238
),
233239
(
240+
"predict_on_full_dataset_with_fixed_size",
234241
predict_on_full_dataset_with_fixed_size,
235242
False,
236243
),
237244
(
238-
predict_on_full_dataset_with_fixed_size,
239-
False,
240-
),
241-
(
245+
"predict_on_full_dataset_with_fixed_size_with_fixed_steps",
242246
predict_on_full_dataset_with_fixed_size_with_fixed_steps,
243247
False,
244248
),
245249
(
250+
"predict_on_full_dataset_without_fixed_size_with_fixed_steps",
246251
predict_on_full_dataset_without_fixed_size_with_fixed_steps,
247252
False,
248253
),
@@ -326,7 +331,7 @@ def assert_result_is_equal(self, result_1, result_2):
326331
else:
327332
self.assertAllClose(result_1, result_2)
328333

329-
@parameterized.parameters(*TESTCASES)
334+
@parameterized.named_parameters(*TESTCASES)
330335
def test_popdist_horovod_are_equal(self, callback, did_weights_change):
331336
"""Tests whether the results of using keras from `keras_extensions_base`
332337
yields the same results as the upstream version after running a callback.
@@ -345,6 +350,27 @@ def test_popdist_horovod_are_equal(self, callback, did_weights_change):
345350
did_weights_change)
346351
self.assert_result_is_equal(popdist_result, horovod_result)
347352

353+
def test_popdist_dataset_truncation(self):
354+
class StepCounterCallback(Callback):
355+
def __init__(self):
356+
super().__init__()
357+
self.step_count = 0
358+
359+
def on_train_batch_begin(self, batch, logs=None):
360+
self.step_count += 1
361+
362+
def fn(model):
363+
cb = StepCounterCallback()
364+
# 5 batches of size 4.
365+
dataset = test_dataset(20, batch_size=4)
366+
model.fit(dataset, callbacks=[cb], verbose=False)
367+
return cb.step_count
368+
369+
_, result = run_with_popdist(fn)
370+
371+
# 5 batches should be truncated to 4 batches (2 instances, 2 per instance).
372+
self.assertEqual(result, 2)
373+
348374

349375
if __name__ == "__main__":
350376
test.main()

0 commit comments

Comments
 (0)