
Commit 7393d5a

hmellor authored and georgepaw committed
Log effective batch size during training when gradient accumulation and/or replication is used
Summary: Unless the user has prior knowledge, it isn't obvious that the batch size provided to `fit()` (or `dataset.batch(batch_size)`) is per replica, rather than the effective batch size from an ML perspective. This confusion is compounded when gradient accumulation is also used. This diff adds a log message that states the effective batch size from the training optimizer's perspective. The message is built dynamically by considering the following three possibilities:

- gradient accumulation > 1: the log only mentions gradient accumulation
- number of replicas > 1: the log only mentions replication
- number of replicas > 1 and gradient accumulation > 1: the log mentions both

TF2.5 only.

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep, markf, vladimirm, christiana

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, christiana

Subscribers: vladimirm

Maniphest Tasks: T56300

Differential Revision: https://phabricator.sourcevertex.net/D61155
1 parent 8cbb17b commit 7393d5a
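For readers new to the per-replica convention, here is a minimal sketch of the arithmetic the new log reports; the batch size, accumulation count, and replica count below are illustrative values, not taken from this commit:

per_replica_batch_size = 16       # value passed to fit() / dataset.batch(batch_size)
gradient_accumulation_steps = 4   # micro-batches accumulated per optimizer step
total_replicas = 8                # replication factor * number of instances

# Effective batch size from the optimizer's perspective: every weight update
# sees the gradients of this many samples.
effective_batch_size = (per_replica_batch_size *
                        gradient_accumulation_steps *
                        total_replicas)
print(effective_batch_size)  # 512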

File tree

2 files changed: +43 -0 lines changed

tensorflow/python/ipu/keras/extensions/data_adapter.py

Lines changed: 9 additions & 0 deletions
@@ -230,6 +230,15 @@ def steps_per_execution_value(self):
   def element_spec(self):
     return self._dataset.element_spec
 
+  @property
+  def batch_size(self):
+    batch_size = self._adapter.batch_size()
+    if batch_size is None and self.element_spec:
+      element_spec = nest.flatten(self.element_spec)[0]
+      if element_spec.shape:
+        batch_size = element_spec.shape[0]
+    return batch_size
+
   def set_replication_factor(self, value):
     self._replication_factor = value
     self._inferred_steps = self._infer_steps(self._steps_per_epoch,
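For context, the fallback in the new `batch_size` property relies on the static batch dimension of the dataset's `element_spec`. Below is a standalone sketch of that inference using stock TensorFlow (assuming `drop_remainder=True` so the batch dimension is static; with a ragged final batch, `shape[0]` would be `None`, which is why the property guards on `element_spec.shape`):

import tensorflow as tf

# A batched dataset; drop_remainder=True makes the batch dimension static.
dataset = tf.data.Dataset.range(100).batch(16, drop_remainder=True)

# The first flattened element spec carries the batch dimension in shape[0],
# mirroring the nest.flatten(...)[0] fallback in the property above.
spec = tf.nest.flatten(dataset.element_spec)[0]
inferred_batch_size = spec.shape[0] if spec.shape else None
print(inferred_batch_size)  # 16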

tensorflow/python/ipu/keras/extensions/keras_extension_base.py

Lines changed: 34 additions & 0 deletions
@@ -139,6 +139,38 @@ def _log_steps_per_execution_warning(self, steps_per_execution):
           self.name, steps_per_execution))
       logged_steps_per_execution_warning = True
 
+  def _log_optimizer_batch_size(self, data_handler):
+    """A function that logs the batch size as seen from the perspective of the
+    optimizer during training.
+
+    Args:
+      data_handler (IPUDataHandler): The data handler created in `fit()`.
+    """
+    # Optimizer batch size depends on the specified batch size, the gradient
+    # accumulation and the replication factor.
+    steps_per_execution = data_handler.steps_per_execution_value
+    gradient_accumulation_steps_per_replica = \
+        self._verify_and_get_gradient_accumulation_steps_per_replica(
+            steps_per_execution)
+    total_replicas = self._get_replication_factor() * popdist.getNumInstances()
+    # Construct a tailored message depending on whether replication, gradient
+    # accumulation, or both are enabled.
+    is_distributed = total_replicas > 1
+    is_accumulated = gradient_accumulation_steps_per_replica > 1
+    if is_accumulated or is_distributed:
+      accumulating_n_batches = \
+          " and accumulating {} batches per optimizer step".format(
+              gradient_accumulation_steps_per_replica)
+      across_n_replicas = " across {} replicas".format(total_replicas)
+      effective_batch_size = data_handler.batch_size * \
+          gradient_accumulation_steps_per_replica * total_replicas
+      logging.info(
+          "Training is{}{}{}, your effective batch size is {}.".format(
+              " distributed" if is_distributed else "",
+              accumulating_n_batches if is_accumulated else "",
+              across_n_replicas if is_distributed else "",
+              effective_batch_size))
+
   def _get_shard_count(self):
     """Returns how many shards the model is parallelized over.
 
@@ -1196,6 +1228,8 @@ def _fit_delegate(self,
       replication_factor = self._get_replication_factor()
       data_handler.set_replication_factor(replication_factor)
 
+      self._log_optimizer_batch_size(data_handler)
+
       # Container that configures and calls `tf.keras.Callback`s.
       if not isinstance(callbacks, callbacks_module.CallbackList):
         callbacks = callbacks_module.CallbackList(
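To make the three cases from the commit message concrete, the following standalone sketch re-creates the message construction used in `_log_optimizer_batch_size` with made-up numbers (batch size 16, 4 accumulation steps, 8 replicas); it is an illustration, not part of the commit:

def format_effective_batch_size(batch_size, accumulation_steps, total_replicas):
  # Mirrors the string construction in _log_optimizer_batch_size above.
  is_distributed = total_replicas > 1
  is_accumulated = accumulation_steps > 1
  accumulating_n_batches = (
      " and accumulating {} batches per optimizer step".format(
          accumulation_steps))
  across_n_replicas = " across {} replicas".format(total_replicas)
  effective_batch_size = batch_size * accumulation_steps * total_replicas
  return "Training is{}{}{}, your effective batch size is {}.".format(
      " distributed" if is_distributed else "",
      accumulating_n_batches if is_accumulated else "",
      across_n_replicas if is_distributed else "",
      effective_batch_size)

print(format_effective_batch_size(16, 4, 1))  # gradient accumulation only
print(format_effective_batch_size(16, 1, 8))  # replication only
print(format_effective_batch_size(16, 4, 8))  # both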
