@@ -882,7 +882,7 @@ def _shard_features(self, features):  # pylint: disable=missing-docstring
         v = tf.expand_dims(v, axis=-1)
         v_shape = [1]
       if v_shape == [1]:
-        v = tf.tile(v, [self._num_datashards])
+        v = tf.tile(v, tf.to_int32([self._num_datashards]))
       sharded_features[k] = self._data_parallelism(
           tf.identity, tf.split(v, self._num_datashards, 0))
     return sharded_features
@@ -1288,17 +1288,17 @@ def _create_host_call(model_dir):
   graph = tf.get_default_graph()
   summaries = graph.get_collection(tf.GraphKeys.SUMMARIES)
 
-  gs_t = tf.reshape(tf.train.get_global_step(), [1])
+  gs_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
   summary_kwargs = collections.OrderedDict()
   for t in summaries:
     if t.op.type != "ScalarSummary":
       continue
 
     name = t.op.name
     tensor = t.op.inputs[1]
-    assert tensor.shape.is_compatible_with(
-        []), ("ScalarSummary %s must have shape [], but is: %s." %
-              (name, tensor.shape))
+    assert tensor.shape.is_compatible_with([])
+    if tensor.dtype == tf.int64:
+      tensor = tf.to_int32(tensor)
     summary_kwargs[name] = tf.reshape(tensor, [1])
   summary_kwargs["global_step"] = gs_t
@@ -1312,7 +1312,7 @@ def host_call_fn(**kwargs):
     Returns:
       List of summary ops to run on the CPU host.
     """
-    gs = kwargs.pop("global_step")[0]
+    gs = tf.to_int64(kwargs.pop("global_step")[0])
     with tf.contrib.summary.create_file_writer(model_dir).as_default():
       with tf.contrib.summary.always_record_summaries():
         for name, value in sorted(six.iteritems(kwargs)):
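Host-side counterpart of the casts above, as a minimal sketch assuming TF 1.x contrib summaries: the int32 step that survived the outfeed is widened back to int64, the dtype tf.contrib.summary takes for its step argument. The function name host_call_fn_sketch and the "/tmp/model" path are hypothetical placeholders.

    import tensorflow as tf  # TF 1.x API assumed

    def host_call_fn_sketch(**kwargs):  # mirrors host_call_fn above
      # Widen the step back to int64 for the contrib summary writer.
      gs = tf.to_int64(kwargs.pop("global_step")[0])
      with tf.contrib.summary.create_file_writer("/tmp/model").as_default():
        with tf.contrib.summary.always_record_summaries():
          for name, value in sorted(kwargs.items()):
            # Each value arrived as shape [1]; unwrap it to a scalar.
            tf.contrib.summary.scalar(name, value[0], step=gs)
          return tf.contrib.summary.all_summary_ops()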