Commit 0b24223

Merge pull request #2667 from AI-Hypercomputer:sft_pathways
PiperOrigin-RevId: 831879716
2 parents ef64c73 + 29eddd9 commit 0b24223

File tree

2 files changed: +23 −4 lines


docs/tutorials/sft_on_multi_host.md

Lines changed: 21 additions & 3 deletions
@@ -50,7 +50,7 @@ The `docker_upload_runner.sh` script uploads your Docker image to Artifact Regis
 Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation-via-pip).
 
 ## 3. Create GKE cluster
-If you don't already have a GKE cluster, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).
+If you don't already have a GKE cluster, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create). Ensure the cluster is Pathways-compatible when running SFT with Pathways.
 
 ## 4. Environment configuration
 ```bash
@@ -89,6 +89,9 @@ If you already have a MaxText-compatible model checkpoint, simply set the follow
 ```bash
 export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
 ```
+**Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags:
+* **For SFT with McJAX:** `checkpoint_storage_use_zarr3=True` and `checkpoint_storage_use_ocdbt=True`.
+* **For SFT with Pathways:** `checkpoint_storage_use_zarr3=False` and `checkpoint_storage_use_ocdbt=False`.
 
 ### Option 2: Converting a Hugging Face checkpoint
 If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
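The storage-flag pairing noted in the hunk above can be sketched as a tiny shell helper. This is an illustrative sketch only — `RUNTIME` is a hypothetical variable, and the tutorial itself expects you to set the flags by hand:

```shell
# Hypothetical helper: derive the checkpoint storage flags from the chosen runtime.
# RUNTIME is illustrative and not part of the tutorial's environment variables.
RUNTIME="pathways"  # use "mcjax" for SFT with McJAX
if [ "$RUNTIME" = "pathways" ]; then
  USE_ZARR3=False
  USE_OCDBT=False
else
  USE_ZARR3=True
  USE_OCDBT=True
fi
echo "checkpoint_storage_use_zarr3=$USE_ZARR3 checkpoint_storage_use_ocdbt=$USE_OCDBT"
```

The same two flags must be consistent between checkpoint creation and the later SFT run, which is why the note above calls them out for both runtimes.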
@@ -102,6 +105,9 @@ export MODEL_CHECKPOINT_PATH=${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/
 2. **Run the Conversion Script:** Execute the following command that downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
 
 ```bash
+USE_ZARR3=<Flag to use zarr3> # True to run SFT with McJAX, False to run SFT with Pathways
+USE_OCDBT=<Flag to use ocdbt> # True to run SFT with McJAX, False to run SFT with Pathways
+
 xpk workload create \
 --cluster=${CLUSTER_NAME} \
 --project=${PROJECT} \
@@ -110,7 +116,7 @@ xpk workload create \
 --workload=ckpt-${WORKLOAD_NAME} \
 --tpu-type=${TPU_TYPE} \
 --num-slices=${TPU_SLICE} \
---command "python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True"
+--command "python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True checkpoint_storage_use_zarr3=$USE_ZARR3 checkpoint_storage_use_ocdbt=$USE_OCDBT"
 ```
 
 ## 6. Submit workload on GKE cluster
@@ -131,4 +137,16 @@ xpk workload create \
 Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
 
 ### 6.2. SFT with Pathways
-Pathways support is coming soon.
+```bash
+xpk workload create-pathways \
+--cluster=${CLUSTER_NAME} \
+--project=${PROJECT} \
+--zone=${ZONE} \
+--docker-image=${DOCKER_IMAGE} \
+--workload=${WORKLOAD_NAME} \
+--tpu-type=${TPU_TYPE} \
+--num-slices=${TPU_SLICE} \
+--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
+```
+
+Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
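The long `--command` value in the Pathways workload above can be read as an environment prefix plus the trainer invocation. A sketch of that decomposition (the shell variable names here are illustrative, and the trainer's flags are elided):

```shell
# Illustrative decomposition of the Pathways --command string.
# PATHWAYS_ENV routes JAX through the Pathways proxy; TRAINER_CMD is the SFT entry point.
PATHWAYS_ENV="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1"
TRAINER_CMD="python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml"
FULL_CMD="$PATHWAYS_ENV $TRAINER_CMD"
echo "$FULL_CMD"
```

Keeping the prefix separate makes it easy to see what differs from the McJAX invocation: only the environment settings, the `create-pathways` verb, and the storage/controller flags change.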

src/MaxText/sft/sft_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -86,7 +86,8 @@ def get_tunix_config(mt_config):
   profiler_options = None
   if mt_config.profiler:
     set_profile_options = True
-    if jax.extend.backend.get_backend().platform_version == "Pathways":
+    platform_version = jax.extend.backend.get_backend().platform_version.strip()
+    if platform_version.startswith("Pathways"):
       max_logging.log("Pathways backend detected. Disabling setting profile options.")
       set_profile_options = False
     profiler_options = profiler.ProfilerOptions(
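The behavioral change in the hunk above can be illustrated with a standalone sketch; `is_pathways_backend` is a hypothetical helper (the real code inlines the check):

```python
def is_pathways_backend(platform_version: str) -> bool:
    # The previous strict equality (== "Pathways") rejected strings carrying a
    # version suffix or surrounding whitespace; strip() + startswith() accepts both.
    return platform_version.strip().startswith("Pathways")

print(is_pathways_backend("Pathways"))          # True, matched before and after the fix
print(is_pathways_backend(" Pathways v0.1\n"))  # True only with the new check
print(is_pathways_backend("PJRT C API"))        # False either way
```

This is why the commit also adds the `.strip()`: a backend reporting, say, a trailing newline in its platform version would otherwise slip past the Pathways detection and hit the unsupported profile options.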
