@@ -139,7 +139,7 @@ def call(self, features):
139139 sharded_logits , losses = self .model_fn_sharded (sharded_features )
140140 if isinstance (sharded_logits , dict ):
141141 concat_logits = {}
142- for k , v in sharded_logits .iteritems ():
142+ for k , v in six .iteritems (sharded_logits ):
143143 concat_logits [k ] = tf .concat (v , 0 )
144144 return concat_logits , losses
145145 else :
@@ -172,7 +172,7 @@ def model_fn_sharded(self, sharded_features):
172172 if isinstance (body_out , dict ):
173173 sharded_logits = {}
174174 sharded_losses = {}
175- for k , v in body_out .iteritems ():
175+ for k , v in six .iteritems (body_out ):
176176 sharded_logits [k ] = dp (self .top , v , datashard_to_features )
177177 sharded_losses [k ] = dp (self .loss , sharded_logits [k ],
178178 datashard_to_features )
@@ -190,8 +190,8 @@ def model_fn_sharded(self, sharded_features):
190190 else :
191191 sharded_logits , sharded_losses = dp (self .model_fn , datashard_to_features )
192192 if isinstance (sharded_logits [0 ], dict ):
193- temp_dict = {k : [] for k , _ in sharded_logits [0 ]. iteritems ( )}
194- for k , _ in sharded_logits [0 ]. iteritems ( ):
193+ temp_dict = {k : [] for k , _ in six . iteritems ( sharded_logits [0 ])}
194+ for k , _ in six . iteritems ( sharded_logits [0 ]):
195195 for l in sharded_logits :
196196 temp_dict [k ].append (l [k ])
197197 sharded_logits = temp_dict
@@ -328,7 +328,7 @@ def top(self, body_output, features):
328328 "The keys of model_body's returned logits dict must match the keys "
329329 "of problem_hparams.target_modality's dict." )
330330 logits = {}
331- for k , v in body_output .iteritems ():
331+ for k , v in six .iteritems (body_output ):
332332 with tf .variable_scope (k ): # TODO(aidangomez): share variables here?
333333 logits [k ] = self ._top_single (v , target_modality [k ], features )
334334 return logits
@@ -362,7 +362,7 @@ def loss(self, logits, features):
362362 "The keys of model_body's returned logits dict must match the keys "
363363 "of problem_hparams.target_modality's dict." )
364364 losses = {}
365- for k , v in logits .iteritems ():
365+ for k , v in six .iteritems (logits ):
366366 losses [k ] = self ._loss_single (v , target_modality [k ], features )
367367 return tf .add_n ([n / d for n , d in losses .values ()])
368368 else :
@@ -927,7 +927,7 @@ def estimator_model_fn(cls,
927927 # Set known shapes
928928 if use_tpu :
929929 if isinstance (logits , dict ):
930- for k , v in logits .iteritems ():
930+ for k , v in six .iteritems (logits ):
931931 if "scalar/" in k :
932932 continue
933933
0 commit comments