Commit 733f924 (1 parent: 083c085)

Make sure gradient accumulation calls the right functions

Summary: Fix T43586 TF2.4 Only
Test Plan: CI
Reviewers: jackh, samuelh, alfiee, #tensorflow, simonl, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Reviewed By: jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Maniphest Tasks: T43586
Differential Revision: https://phabricator.sourcevertex.net/D49238

File tree: 1 file changed (+7, -2 lines)

tensorflow/python/ipu/keras/optimizers/ipu_wrappers.py (7 additions, 2 deletions)

@@ -16,6 +16,7 @@
 Convenience wrappers for v2 optimizers
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
+from tensorflow.python.keras.optimizer_v1 import TFOptimizer
 from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
 from tensorflow.python.training.optimizer import Optimizer

@@ -169,8 +170,12 @@ def compute_gradients(self,

     v = var_list if not self._model else self._model.trainable_weights

-    grads = self._optimizer.get_gradients(loss, v)
-    grads_and_vars = zip(grads, v)
+    if isinstance(self._optimizer, TFOptimizer):
+      grads_and_vars = self._optimizer.get_grads(loss, v)
+    else:
+      grads = self._optimizer.get_gradients(loss, v)
+      grads_and_vars = zip(grads, v)

     return list(map(self.preprocess_gradients, grads_and_vars))

   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
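The patch dispatches on the wrapped optimizer's type because the two APIs return different shapes: per the diff, the v1 `TFOptimizer` wrapper's `get_grads` already yields `(gradient, variable)` pairs, while `OptimizerV2.get_gradients` yields gradients only, which must be zipped with the variables. A minimal, framework-free sketch of that dispatch pattern, where `LegacyOptimizer` and `V2Optimizer` are hypothetical stand-ins for `TFOptimizer` and `OptimizerV2` (not the real TensorFlow classes):

```python
class LegacyOptimizer:
    # Stand-in for the v1 TFOptimizer wrapper: returns (grad, var) pairs.
    def get_grads(self, loss, var_list):
        # Toy gradient of loss = v**2, i.e. d(v**2)/dv = 2v.
        return [(2.0 * v, v) for v in var_list]

class V2Optimizer:
    # Stand-in for OptimizerV2: returns gradients only, no variables.
    def get_gradients(self, loss, var_list):
        return [2.0 * v for v in var_list]

def compute_grads_and_vars(optimizer, loss, var_list):
    # Mirrors the patched compute_gradients: call the API that matches
    # the optimizer's type instead of always calling get_gradients.
    if isinstance(optimizer, LegacyOptimizer):
        grads_and_vars = optimizer.get_grads(loss, var_list)
    else:
        grads = optimizer.get_gradients(loss, var_list)
        grads_and_vars = zip(grads, var_list)
    return list(grads_and_vars)

print(compute_grads_and_vars(LegacyOptimizer(), None, [1.0, 3.0]))
print(compute_grads_and_vars(V2Optimizer(), None, [1.0, 3.0]))
```

Either path ends at the same `(gradient, variable)` shape, so the downstream `preprocess_gradients` mapping in the real code is unaffected by which branch ran.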
