Skip to content

Commit 3a137c8

Browse files
committed
Allow user-specified num_slices
1 parent 8d6acdd commit 3a137c8

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

src/MaxText/checkpointing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,9 @@ def _restore_grain_iterator(
427427
elif expansion_factor_real_data > 1 and process_count_stored == process_count_jax // expansion_factor_real_data:
428428
# Scaling up to a larger number of hosts.(e.g., 32 files -> 64 processes)
429429
# In this case, a subset of hosts restore the data iterator.
430-
assert not isinstance(data_iterator, list), "when expansion_factor_real_data > 1, the data iterator should not be a list."
430+
assert not isinstance(
431+
data_iterator, list
432+
), "when expansion_factor_real_data > 1, the data iterator should not be a list."
431433
grain_restore_args = GrainCheckpointRestore(
432434
data_iterator.local_iterator, process_index=jax.process_index(), process_count=process_count_stored
433435
)

src/MaxText/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,8 @@ ici_expert_parallelism: 1
471471
# Enable ZeRO-1 optimizer sharding over data axis
472472
shard_optimizer_over_data: False
473473

474-
# The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation,
474+
# Unless explicitly specified, the number of TPU slices is automatically determined. It should only be set for
475+
# disaggregated reinforcement learning workloads using multiple slices. For ahead of time compilation,
475476
# you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
476477
num_slices: -1
477478

src/MaxText/max_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ def _retrieve_jax_init_info(raw_keys):
289289

290290
def get_num_slices(raw_keys):
291291
"""Calculate num_slices based on number of devices."""
292+
if raw_keys["num_slices"] != -1:
293+
max_logging.log(f"Using num_slices={raw_keys['num_slices']} per user request.")
294+
return raw_keys["num_slices"]
292295
if raw_keys["hardware"] == "cpu":
293296
max_logging.log(" Setting num_slices=1 for CPU hardware type")
294297
return 1

0 commit comments

Comments
 (0)