File tree Expand file tree Collapse file tree 2 files changed +6
-0
lines changed Expand file tree Collapse file tree 2 files changed +6
-0
lines changed Original file line number Diff line number Diff line change 3333from tensorflow .python .ops import control_flow_util
3434from tensorflow .python .ops import control_flow_ops
3535from tensorflow .python .ops import variable_scope
36+ from tensorflow .python .training .tracking import base as trackable
3637from tensorflow .python .util import tf_contextlib
3738from tensorflow .python .util import deprecation
3839
@@ -335,6 +336,8 @@ def initial_value_fn(): # pylint: disable=g-missing-docstring
335336 initial_value = kwargs ["initial_value" ]
336337 if callable (initial_value ):
337338 initial_value = initial_value ()
339+ if isinstance (initial_value , trackable .CheckpointInitialValue ):
340+ initial_value = initial_value .wrapped_value
338341 assert not callable (initial_value )
339342 initial_value = ops .convert_to_tensor (initial_value ,
340343 dtype = kwargs .get ("dtype" , None ))
Original file line number Diff line number Diff line change 3636from tensorflow .python .keras .engine import sequential
3737from tensorflow .python .ipu .ops import cross_replica_ops
3838from tensorflow .python .ops import control_flow_ops
39+ from tensorflow .python .training .tracking import base as trackable
3940from tensorflow .python .util import nest
4041from tensorflow .python .ipu .keras .extensions import functional_extensions
4142from tensorflow .python .ipu .keras .extensions import sequential_extensions
@@ -126,6 +127,8 @@ def initial_value_fn():
126127 initial_value = kwargs ["initial_value" ]
127128 if callable (initial_value ):
128129 initial_value = initial_value ()
130+ if isinstance (initial_value , trackable .CheckpointInitialValue ):
131+ initial_value = initial_value .wrapped_value
129132 assert not callable (initial_value )
130133 return ops .convert_to_tensor (initial_value ,
131134 dtype = kwargs .get ("dtype" , None ))
You can’t perform that action at this time.
0 commit comments