Skip to content

Commit fa7bdcc

Browse files
Merge pull request #2702 from AI-Hypercomputer:xfgu-multislice-rl
PiperOrigin-RevId: 833880778
2 parents aca5b24 + 72ef58e commit fa7bdcc

File tree

3 files changed

+141
-104
lines changed

3 files changed

+141
-104
lines changed

src/MaxText/configs/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ base_config: "base.yml"
2121
trainer_devices_fraction: 0.5
2222
sampler_devices_fraction: 0.5
2323
chips_per_vm: 4 # depends on hardware; for v5p this is 4
24+
num_trainer_slices: -1
25+
num_samplers_slices: -1
2426

2527
# ====== Reproducibility ======
2628
data_shuffle_seed: 42

src/MaxText/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1223,6 +1223,8 @@ class RLHardware(BaseModel):
12231223
sampler_devices_fraction: float = Field(0.5, description="Fraction of devices to use for the sampler.")
12241224
chips_per_vm: int = Field(4, description="Number of accelerator chips per VM.")
12251225
use_pathways: bool = Field(True, description="Whether to use Pathways for multihost orchestration.")
1226+
num_trainer_slices: int = Field(-1, description="Number of slices for the trainer.")
1227+
num_samplers_slices: int = Field(-1, description="Number of slices for the samplers.")
12261228

12271229

12281230
class VLLM(BaseModel):

0 commit comments

Comments
 (0)