Skip to content

Commit c7086ef

Browse files
committed
Handle Checkpoints in strategy
Summary: This is due to upstream change 2c9ffb5 Ref T53780 TF2.5 Only Reviewers: alfiee, christiana, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Reviewed By: alfiee, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Maniphest Tasks: T53780 Differential Revision: https://phabricator.sourcevertex.net/D62093
1 parent c85ba55 commit c7086ef

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

tensorflow/python/ipu/ipu_multi_worker_strategy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tensorflow.python.ops import control_flow_util
3434
from tensorflow.python.ops import control_flow_ops
3535
from tensorflow.python.ops import variable_scope
36+
from tensorflow.python.training.tracking import base as trackable
3637
from tensorflow.python.util import tf_contextlib
3738
from tensorflow.python.util import deprecation
3839

@@ -335,6 +336,8 @@ def initial_value_fn(): # pylint: disable=g-missing-docstring
335336
initial_value = kwargs["initial_value"]
336337
if callable(initial_value):
337338
initial_value = initial_value()
339+
if isinstance(initial_value, trackable.CheckpointInitialValue):
340+
initial_value = initial_value.wrapped_value
338341
assert not callable(initial_value)
339342
initial_value = ops.convert_to_tensor(initial_value,
340343
dtype=kwargs.get("dtype", None))

tensorflow/python/ipu/ipu_strategy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from tensorflow.python.keras.engine import sequential
3737
from tensorflow.python.ipu.ops import cross_replica_ops
3838
from tensorflow.python.ops import control_flow_ops
39+
from tensorflow.python.training.tracking import base as trackable
3940
from tensorflow.python.util import nest
4041
from tensorflow.python.ipu.keras.extensions import functional_extensions
4142
from tensorflow.python.ipu.keras.extensions import sequential_extensions
@@ -126,6 +127,8 @@ def initial_value_fn():
126127
initial_value = kwargs["initial_value"]
127128
if callable(initial_value):
128129
initial_value = initial_value()
130+
if isinstance(initial_value, trackable.CheckpointInitialValue):
131+
initial_value = initial_value.wrapped_value
129132
assert not callable(initial_value)
130133
return ops.convert_to_tensor(initial_value,
131134
dtype=kwargs.get("dtype", None))

0 commit comments

Comments
 (0)