Commit fb76fd0 (parent: e3ddb1a)

Improve SFT documentation

3 files changed: 95 additions & 58 deletions

docs/tutorials/sft.md

Lines changed: 60 additions & 43 deletions
@@ -14,81 +14,98 @@
limitations under the License.
-->

-# Try SFT
+# Supervised Fine-Tuning (SFT) on Single-Host TPUs

Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.

-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B model on the [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset using SFT. If you wish to use a different dataset, you can [update the dataset configurations](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft.yml).
+This tutorial provides step-by-step instructions for setting up the environment and then fine-tuning a model on a Hugging Face dataset using SFT.

We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.

In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started!

## Setup virtual environment

-### Create a Python3.12 virtual environment
+### Create a Python 3.12 virtual environment and install MaxText dependencies
```sh
+git clone https://github.com/google/maxtext.git
+cd maxtext
bash tools/setup/setup.sh
```

### Activate virtual environment
```
# Replace with your virtual environment name if not using this default name
venv_name="maxtext_venv"
-source ~/$venv_name/bin/activate
+source $venv_name/bin/activate
```
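Before moving on, you can optionally confirm the environment is ready. A minimal sanity check, assuming the setup script above installed JAX with TPU support inside this venv:

```sh
# Optional: verify the interpreter version and that JAX can see the TPU chips.
python3 --version
python3 -c "import jax; print(jax.devices())"
```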

-### Install MaxText dependencies
-```
-bash tools/setup/setup.sh
+## Setup environment variables
+Set the following environment variables before running SFT.
+```sh
+# -- Model configuration --
+export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
+export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
+export HF_TOKEN=<Hugging Face access token>
+
+# -- MaxText configuration --
+export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
+export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+export STEPS=<number of fine-tuning steps to run> # e.g., 1000
+export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1
+
+# -- Dataset configuration --
+export DATASET_NAME=<Hugging Face dataset name> # e.g., HuggingFaceH4/ultrachat_200k
+export TRAIN_SPLIT=<data split for train> # e.g., train_sft
+export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['messages']
```
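For reference, a fully filled-in configuration using the example values from the comments above might look like the following; the bucket, run name, and token shown here are placeholders to replace with your own:

```sh
# Example values only; substitute your own bucket, access token, and model choice.
export PRE_TRAINED_MODEL='llama3.1-8b'
export PRE_TRAINED_MODEL_TOKENIZER='meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=hf_xxxxxxxxxxxxxxxx        # placeholder Hugging Face access token
export BASE_OUTPUT_DIRECTORY=gs://my-bucket/my-output-directory
export RUN_NAME=$(date +%Y-%m-%d-%H-%M-%S)
export STEPS=1000
export PER_DEVICE_BATCH_SIZE=1
export DATASET_NAME=HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=train_sft
export TRAIN_DATA_COLUMNS="['messages']"
```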

-## Run SFT
-There are two scenarios supported for running SFT:
-1. **Run SFT on Hugging Face checkpoint**
-Download the checkpoint directly from Hugging Face and fine-tune it using SFT.
-
-2. **Run SFT on MaxText checkpoint**
-Use a checkpoint generated by MaxText and fine-tune it using SFT.
+## Get your model checkpoint
+This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

-Choose the scenario that matches your workflow and follow the corresponding instructions below.
+### Option 1: Using an existing MaxText checkpoint
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

-### Run SFT on Hugging Face checkpoint
-* The script will first convert a Hugging Face checkpoint to a MaxText checkpoint.
-* It then runs SFT on this converted checkpoint.
-
-#### Setup environment variables
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```
-export HF_TOKEN=<Hugging Face access token>

-export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>
+### Option 2: Converting a Hugging Face checkpoint
+If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

-export STEPS=<number of fine-tuning steps to run>
+1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.

-export PER_DEVICE_BATCH_SIZE=1
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items
```

-Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh):
-```
-bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh
-```
+2. **Run the Conversion Script:** Run the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face; for the specific models and versions currently supported, refer to the [`HF_IDS` dictionary in the MaxText checkpoint-conversion utilities](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).

-### Run SFT on MaxText checkpoint
-* The script directly runs SFT on MaxText checkpoint.
+```sh
+pip install torch # Ensure torch is installed for the conversion script

-#### Setup environment variables
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+  model_name=${PRE_TRAINED_MODEL} \
+  hf_access_token=${HF_TOKEN} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \
+  scan_layers=True
```
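If you are unsure whether your model is covered by the converter, you can inspect the `HF_IDS` dictionary mentioned above directly from your venv. A minimal sketch, assuming the dictionary's module path matches the linked file and the package imports cleanly:

```sh
# Print the model names the conversion script currently maps to Hugging Face repos.
# Assumes MaxText is importable from this venv and HF_IDS is defined in the linked utils module.
python3 -c "from MaxText.utils.ckpt_conversion.utils.utils import HF_IDS; print(sorted(HF_IDS))"
```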
-export HF_TOKEN=<Hugging Face access token>
-
-export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>
-
-export STEPS=<number of fine-tuning steps to run>

-export PER_DEVICE_BATCH_SIZE=1
+## Run SFT on a Hugging Face dataset
+Now you are ready to run SFT using the following command:

-export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint>
-```
-
-Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh):
-```
-bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh
+```sh
+python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  model_name=${PRE_TRAINED_MODEL} \
+  load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
+  hf_access_token=${HF_TOKEN} \
+  tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \
+  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
+  steps=${STEPS} \
+  hf_path=${DATASET_NAME} \
+  train_split=${TRAIN_SPLIT} \
+  train_data_columns=${TRAIN_DATA_COLUMNS} \
+  profiler=xplane
```
+Your fine-tuned model checkpoints will be saved to `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
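To confirm the run produced checkpoints, you can list the output location once training finishes. A minimal check, assuming `gsutil` is installed and authenticated for your bucket:

```sh
# List the checkpoint steps written by the SFT run.
gsutil ls ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/
```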

docs/tutorials/sft_on_multi_host.md

Lines changed: 33 additions & 13 deletions
@@ -14,8 +14,14 @@
# limitations under the License.
-->

-# Supervised Fine-Tuning (SFT) with Deepseek-V3 model
-This guide provides step by step instructions to run SFT with Deepseek-V3 model on TPU v6e-256. Deepseek-V3 is a Mixture-of-Experts (MoE) language model with 671B parameters.
+# Supervised Fine-Tuning (SFT) on Multi-Host TPUs
+Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
+
+This tutorial provides step-by-step instructions for setting up a multi-host TPU environment, such as `v6e-256`, and then fine-tuning a model on a Hugging Face dataset using SFT.
+
+We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.
+
+Let's get started!

## 1. Build and upload MaxText Docker image
This section guides you through cloning the MaxText repository, building MaxText Docker image with dependencies, and uploading the docker image to your project's Artifact Registry.
@@ -28,7 +34,7 @@ cd maxtext

### 1.2. Build MaxText Docker image
```bash
-bash dependencies/scripts/docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
+bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
```
This creates a local Docker image named `maxtext_base_image`.
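After the build completes, it can be worth verifying that the image is present locally before uploading it. A quick check, assuming Docker is running on the build machine:

```bash
# Confirm the freshly built base image exists locally.
docker images | grep maxtext_base_image
```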

@@ -44,7 +50,7 @@ The `docker_upload_runner.sh` script uploads your Docker image to Artifact Regis
Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation-via-pip).

## 3. Create GKE cluster
-If you don't already have a GKE cluster with a `v6e-256` TPU slice available, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).
+If you don't already have a GKE cluster, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).

## 4. Environment configuration
```bash
@@ -54,20 +60,32 @@ export CLUSTER_NAME=<Name of GKE Cluster>
export ZONE=<GKE Cluster Zone>

# -- Workload Configuration --
-export WORKLOAD_NAME="sft-$(date +%Y-%m-%d-%H-%M-%S)" # Or your desired workload name
-export TPU_TYPE=v6e-256
+export WORKLOAD_NAME=<Name of Workload> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+export TPU_TYPE=<TPU Type> # e.g., v6e-256
export TPU_SLICE=1
export DOCKER_IMAGE="gcr.io/${PROJECT}/${DOCKER_IMAGE_NAME}"

# -- MaxText Configuration --
-export OUTPUT_PATH=<GCS Bucket Path for output/logs>
-export STEPS=100 # Number of fine-tuning steps to run
-export HF_TOKEN=<Hugging Face access token>
-export MODEL_CHECKPOINT_PATH=<GCS path to model checkpoint>
+export OUTPUT_PATH=<GCS Path for Output/Logs> # e.g., gs://my-bucket/my-output-directory
+export STEPS=<Fine-Tuning Steps> # e.g., 1000
+export HF_TOKEN=<Hugging Face Access Token>
+
+# -- Model Configuration --
+export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
+export TOKENIZER_PATH=<Model Tokenizer> # e.g., deepseek-ai/DeepSeek-V3
+export MODEL_CHECKPOINT_PATH=<GCS Path to Model Checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
+
+# -- Dataset configuration --
+export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k
+export TRAIN_SPLIT=<Data Split for Train> # e.g., train_sft
+export TRAIN_DATA_COLUMNS=<Data Columns to Train on> # e.g., ['messages']
```
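Before submitting the workload, you may want to confirm that `${DOCKER_IMAGE}` was actually uploaded to your registry. A minimal check, assuming the gcloud CLI is authenticated for the project:

```bash
# Verify the uploaded image is visible in the container registry.
gcloud container images describe ${DOCKER_IMAGE}
```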

## 5. Submit workload on GKE cluster
-This section provides the command to run SFT with Deepseek-v3 model on a v6e-256 GKE cluster.
+This section provides the command to run SFT on a GKE cluster.
+
+### 5.1. SFT with Multi-Controller JAX (McJAX)
+
```bash
xpk workload create \
  --cluster=${CLUSTER_NAME} \
@@ -77,7 +95,9 @@ xpk workload create \
  --workload=${WORKLOAD_NAME} \
  --tpu-type=${TPU_TYPE} \
  --num-slices=${TPU_SLICE} \
-  --command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=deepseek3-671b load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=deepseek-ai/DeepSeek-V3 per_device_batch_size=1 steps=$STEPS profiler=xplane megablox=False sparse_matmul=False ici_expert_parallelism=16 ici_fsdp_parallelism=16 weight_dtype=bfloat16 dtype=bfloat16 remat_policy=full decoder_layer_input=offload sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048 opt_type=sgd attention=flash capacity_factor=1.0 max_target_length=2048"
+  --command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS"
```
-Once the fine-tuning is completed, you can access your model checkpoint at `${OUTPUT_PATH}/${WORKLOAD_NAME}/checkpoints/${STEPS}/model_params`.
+Once fine-tuning is complete, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.

+### 5.2. SFT with Pathways
+Pathways support is coming soon.
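Once the workload has been submitted, you can monitor it from your workstation with XPK. A minimal sketch, assuming the `xpk` CLI installed in step 2 is on your PATH:

```bash
# Check the status of workloads running on the cluster.
xpk workload list --cluster=${CLUSTER_NAME}
```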

src/MaxText/examples/README_how_to_run_examples.md

Lines changed: 2 additions & 2 deletions
@@ -122,8 +122,8 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla

### Supervised Fine-Tuning (SFT)

-- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B with Hugging Face ultrachat_200k dataset
-- **`sft_llama3_demo.ipynb`** → Llama3.1-8B with Hugging Face ultrachat_200k dataset
+- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
+- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)

### GRPO Training
