Skip to content

Commit 72ef58e

Browse files
committed
Introduce Multislice RL
1 parent 0eb74b4 commit 72ef58e

File tree

3 files changed: 141 additions and 104 deletions.

src/MaxText/configs/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ base_config: "base.yml"
 trainer_devices_fraction: 0.5
 sampler_devices_fraction: 0.5
 chips_per_vm: 4 # depends on hardware, for v5p this is 4
+num_trainer_slices: -1
+num_samplers_slices: -1
 
 # ====== Reproducibility ======
 data_shuffle_seed: 42

src/MaxText/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,8 @@ class RLHardware(BaseModel):
   sampler_devices_fraction: float = Field(0.5, description="Fraction of devices to use for the sampler.")
   chips_per_vm: int = Field(4, description="Number of accelerator chips per VM.")
   use_pathways: bool = Field(True, description="Whether to use Pathways for multihost orchestration.")
+  num_trainer_slices: int = Field(-1, description="Number of slices for the trainer.")
+  num_samplers_slices: int = Field(-1, description="Number of slices for the samplers.")
 
 
 class VLLM(BaseModel):

0 commit comments

Comments
 (0)