Skip to content

Commit 3d6b929

Browse files
committed
fix gather global error
1 parent 151fa9f commit 3d6b929

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/forward_pass_logit_checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def main(config, test_args): # pylint: disable=W0621
279279
rngs={"aqt": init_rng},
280280
)
281281

282-
full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
282+
full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits, tiled=True)
283283
# if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size]
284284
if full_train_logits.ndim == 4:
285285
full_train_logits = jnp.reshape(full_train_logits, (-1, config.max_target_length, config.vocab_size))

0 commit comments

Comments
 (0)