Skip to content

Commit 17c268f

Browse files
tomsiadevnv-kkudrynski
authored andcommitted
[SIM/TF2] Fix concat bug from TensorFlow 2.11
1 parent 7475648 commit 17c268f

File tree

1 file changed

+4
-1
lines changed
  • TensorFlow2/Recommendation/SIM

1 file changed

+4
-1
lines changed

TensorFlow2/Recommendation/SIM/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,10 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
253253
local = tf.constant(local)
254254

255255
# concat all local variables into a single tensor
256-
local = tf.concat(local, 0)
256+
if local is local_total_losses:
257+
local = tf.stack(local, 0)
258+
else:
259+
local = tf.concat(local, 0)
257260

258261
# for single element lists, tf.concat will produce shape=() instead of shape=(1,).
259262
# reshape it for hvd.allgather to work

0 commit comments

Comments
 (0)