Skip to content

Commit 60ea9f3

Browse files
Merge pull request #2690 from AI-Hypercomputer:post_training_docker_fix
PiperOrigin-RevId: 832376901
2 parents c92d9d9 + f958ed7 commit 60ea9f3

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,18 @@ ENV MODE=$MODE
2020

2121
RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
2222

23-
2423
# Uninstall existing jax to avoid conflicts
25-
RUN uv pip uninstall -y jax jaxlib libtpu
24+
RUN pip uninstall -y jax jaxlib libtpu
2625

27-
RUN uv pip install aiohttp==3.12.15
26+
RUN pip install aiohttp==3.12.15
2827

29-
RUN uv pip install numba==0.61.2
28+
RUN pip install numba==0.61.2
3029

3130
# Install vLLM for Jax and TPUs
32-
RUN uv pip install vllm-tpu
31+
RUN pip install vllm-tpu
3332

3433
RUN if [ "$MODE" = "post-training-experimental" ]; then \
35-
uv pip uninstall -y jax jaxlib libtpu && \
36-
uv pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
37-
uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
34+
pip uninstall -y jax jaxlib libtpu && \
35+
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
36+
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
3837
fi

dependencies/scripts/docker_build_dependency_image.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
150150
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile' \
151151
-t ${LOCAL_IMAGE_NAME} .
152152
elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then
153-
echo "Installing MaxText stable mode dependencies for GRPO"
153+
echo "Installing MaxText stable mode dependencies for post-training"
154154
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION \
155155
--build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \
156156
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \

0 commit comments

Comments
 (0)