dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile
@@ -22,38 +22,17 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M


# Uninstall existing jax to avoid conflicts
RUN pip uninstall -y jax jaxlib libtpu

RUN pip install aiohttp==3.12.15

# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
RUN pip install keyring keyrings.google-artifactregistry-auth

RUN pip install numba==0.61.2

# Install vLLM for Jax and TPUs from the artifact registry
RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu

# Install tpu-commons from the artifact registry
RUN pip install --no-cache-dir --pre \
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
tpu-commons==0.1.2
RUN uv pip uninstall -y jax jaxlib libtpu

RUN uv pip install aiohttp==3.12.15

RUN uv pip install numba==0.61.2

# Install vLLM for Jax and TPUs (vllm-tpu bundles vllm and tpu-inference)
RUN uv pip install vllm-tpu

RUN if [ "$MODE" = "post-training-experimental" ]; then \
pip uninstall -y jax jaxlib libtpu && \
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
uv pip uninstall -y jax jaxlib libtpu && \
uv pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
fi
58 changes: 58 additions & 0 deletions dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
@@ -0,0 +1,58 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASEIMAGE
FROM ${BASEIMAGE}
ARG MODE
ENV MODE=$MODE

RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
RUN pip uninstall -y jax jaxlib libtpu

RUN pip install aiohttp==3.12.15

# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
RUN pip install keyring keyrings.google-artifactregistry-auth

RUN pip install numba==0.61.2

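# Install Tunix in editable mode from the source copied into the build context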
COPY tunix /tunix
RUN pip install -e /tunix --no-cache-dir


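# Build and install vLLM for TPU in editable mode from the local source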
COPY vllm /vllm
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html


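# Install tpu-inference (the TPU backend for vLLM) in editable mode from the local source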
COPY tpu-inference /tpu-inference
RUN pip install -e /tpu-inference --no-cache-dir --pre \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html


RUN if [ "$MODE" = "post-training-experimental" ]; then \
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
pip uninstall -y jax jaxlib libtpu && \
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
fi
57 changes: 37 additions & 20 deletions dependencies/scripts/docker_build_dependency_image.sh
@@ -6,7 +6,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,14 +20,15 @@
# bash docker_build_dependency_image.sh MODE=nightly
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
# Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
# MODE=custom_wheels is the same as nightly except that it reinstalls any
# additional wheels that are present in the maxtext directory.
# The main use case is to install custom jax or jaxlib wheels but it also
# works with any custom wheels.
# bash docker_build_dependency_image.sh MODE=custom_wheels

# bash docker_build_dependency_image.sh MODE=post-training
# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local

if [ "${BASH_SOURCE-}" ]; then
this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then
echo "Default DEVICE=${DEVICE}"
fi

# New flag for post-training source
if [[ -z ${POST_TRAINING_SOURCE} ]]; then
export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
fi

# Function to build with MODE=jax_ai_image
build_ai_image() {
if [[ -z ${BASEIMAGE+x} ]]; then
@@ -171,24 +178,34 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
exit 1
fi

# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
# rsync -a --exclude='__pycache__' ../tpu_commons .
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
# # This assumes vllm is a sibling directory to the current one (maxtext).
# rsync -a --exclude='__pycache__' ../vllm .

# rsync -a --exclude='__pycache__' ../tunix .

# # The cleanup is set to run even if the build fails to remove the copied directory.
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM

docker build \
--network host \
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
--build-arg MODE=${MODE} \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
DOCKERFILE_NAME=""
if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then

# To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
# This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
rsync -a --exclude='__pycache__' ../tpu-inference .
# To install vllm from a local path, we copy it into the build context, excluding __pycache__.
# This assumes vllm is a sibling directory to the current one (maxtext).
rsync -a --exclude='__pycache__' ../vllm .

rsync -a --exclude='__pycache__' ../tunix .

# The cleanup is set to run even if the build fails, so the copied directories are always removed.
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM

DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME"
else
DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME"
fi

docker build \
--network host \
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
--build-arg MODE=${MODE} \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \
-t ${LOCAL_IMAGE_NAME} .
fi

if [[ ${CUSTOM_JAX} -eq 1 ]] ; then
41 changes: 13 additions & 28 deletions docs/tutorials/grpo.md
@@ -25,35 +25,20 @@ And we use vLLM as the library for efficient model inference and generation.

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!

## Setup your virtual environment
## Create a virtual environment and install MaxText dependencies
Follow the instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md); we recommend creating the virtual environment outside the `maxtext` directory.

### Create a Python3.12 venv if not already pre-existing and install MaxText dependencies
```sh
bash tools/setup/setup.sh
```

### Activate your virtual environment (Skip if you have already done this for running `bash tools/setup/setup.sh` )
```
# Replace with your virtual environment name if not using this default name
venv_name="maxtext_venv"
source ~/$venv_name/bin/activate
```

## vLLM and tpu-commons installations
## vLLM and tpu-inference installations

Next, run the following bash script to get all the necessary installations inside the virtual environment.
Next, run the following bash script to install all the necessary packages inside the virtual environment (e.g., `maxtext_venv`).
This will take a few minutes. Follow the installation logs and look out for any issues!

```
bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh
```

1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
3. Then, it installs `tpu-commons` from the same artifact registry.

`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
Primarily, it installs `vllm-tpu`, which bundles [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
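
As a quick sanity check (a minimal sketch — exact version strings and device lists will differ on your machine), you can confirm that the TPU stack imports cleanly inside the virtual environment:

```
python3 -c "import vllm; print('vllm', vllm.__version__)"
python3 -c "import jax; print('jax devices:', jax.devices())"
```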


## Run GRPO
@@ -62,15 +47,15 @@ Finally, run the command

```
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-8b \
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
--load_parameters_path=gs://path/to/checkpoint/0/items \
--run_name=$WORKLOAD \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN
model_name=llama3.1-8b \
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
load_parameters_path=gs://path/to/checkpoint/0/items \
run_name=$WORKLOAD \
base_output_directory=$OUTPUT_PATH \
hf_access_token=$HF_TOKEN
```
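
The command above assumes `$WORKLOAD`, `$OUTPUT_PATH`, and `$HF_TOKEN` are already set in your shell. A hypothetical example (substitute your own run name, GCS bucket, and Hugging Face token):

```
export WORKLOAD=grpo-llama31-8b                 # any run name you like
export OUTPUT_PATH=gs://your-bucket/rl-runs     # assumption: a GCS bucket you own
export HF_TOKEN=<your-huggingface-access-token>
```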

The overview of the demo script is as follows:
An overview of what this run will do is as follows:

1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
37 changes: 19 additions & 18 deletions docs/tutorials/grpo_with_pathways.md
Original file line number Diff line number Diff line change
@@ -24,27 +24,28 @@ We use Tunix as the library for GRPO.
And we use vLLM as the library for efficient model inference and generation.

Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro). Using Pathways, you can also run GRPO in a disaggregated mode where the trainer and the samplers run on separate meshes. Try out the following recipe on a `v5p-64`. You can submit jobs to a Pathways-enabled GKE cluster.

## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-commons dependencies
Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-commons installed.

In addition to MaxText dependencies,

1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
3. Then, it installs `tpu-commons` from the same artifact registry.
## Create a virtual environment and install MaxText dependencies
Follow the instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md); we recommend creating the virtual environment outside the `maxtext` directory.

## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-inference dependencies

`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
### Installing stable releases of tunix and vllm-tpu
Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-inference installed.

In addition to the MaxText dependencies, it primarily installs `vllm-tpu`, which bundles [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.

```
bash docker_build_dependency_image.sh MODE=post-training
bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
```
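
As a quick post-build sanity check, you can try importing vLLM inside the image (a sketch — `maxtext_base_image` is assumed to be the script's default tag; use whatever tag the script actually prints):

```
docker run --rm maxtext_base_image python3 -c "import vllm; print(vllm.__version__)"
```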

You can also use `bash docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies, such as the improved pathwaysutils resharding API.

### Install from locally git-cloned repos

You can also git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), and [vllm](https://github.com/vllm-project/vllm.git) locally, and then build a docker image from them with the following command (see the sketch below):
`bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local`
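
For example, here is a sketch assuming the three repos are cloned as siblings of your `maxtext` checkout, which is the layout the build script's `rsync` step expects:

```
cd ..   # the parent directory of your maxtext checkout
git clone https://github.com/google/tunix.git
git clone https://github.com/vllm-project/tpu-inference.git
git clone https://github.com/vllm-project/vllm.git
cd maxtext
bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
```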

### Upload the dependency docker image along with MaxText code
```
@@ -61,12 +62,12 @@ xpk workload create-pathways --workload $WORKLOAD \
--project=$PROJECT_ID --priority=high \
--command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-70b \
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
--load_parameters_path=gs://path/to/checkpoint/0/items \
--run_name=$WORKLOAD \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN"
model_name=llama3.1-70b \
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
load_parameters_path=gs://path/to/checkpoint/0/items \
run_name=$WORKLOAD \
base_output_directory=$OUTPUT_PATH \
hf_access_token=$HF_TOKEN"
```

The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:
1 change: 1 addition & 0 deletions src/MaxText/configs/rl.yml
@@ -71,6 +71,7 @@ value_proj: 'offload'
checkpoint_storage_use_ocdbt: False # For Pathways
checkpoint_storage_use_zarr3: False # For Pathways
use_pathways: True
log_period: 20

# ====== Debugging ======
debug:
43 changes: 12 additions & 31 deletions src/MaxText/examples/install_tunix_vllm_requirement.sh
@@ -19,34 +19,15 @@
set -e
set -x

python -m ensurepip --default-pip

pip uninstall -y jax jaxlib libtpu

pip install aiohttp==3.12.15

# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
pip install keyring keyrings.google-artifactregistry-auth

# Install vLLM for Jax and TPUs from the artifact registry
VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu

# Install tpu-commons from the artifact registry
pip install --no-cache-dir --pre \
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
tpu-commons==0.1.2

pip install numba==0.61.2
uv pip uninstall -y jax jaxlib libtpu

uv pip install aiohttp==3.12.15

# Install vLLM for Jax and TPUs (vllm-tpu bundles vllm and tpu-inference)
uv pip install vllm-tpu

uv pip install numba==0.61.2

uv pip install qwix==0.1.1

uv pip install flax==0.11.1