7 changes: 1 addition & 6 deletions docs/tutorials/grpo_with_pathways.md
@@ -69,9 +69,4 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--hf_access_token=$HF_TOKEN"
```

The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:

1. We load a policy model and a reference model. Both are copies of `Llama3.1-70b-Instruct`.
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
3. Train the policy model using GRPO with potentially different meshes for trainer and rollout depending on the parameters `TRAINER_DEVICES_FRACTION` and `SAMPLER_DEVICES_FRACTION`. If we set both of these to `1.0`, the entire (same) mesh will be used for both trainer and rollout. If we set say `TRAINER_DEVICES_FRACTION=0.5` and `SAMPLER_DEVICES_FRACTION=0.5`, the first half of the devices will be used for trainer and the second half will be used for rollout
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
For an interactive walkthrough, open `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb`. The notebook now delegates to `rl_train`, so you can reuse the same configuration flags shown above (including `trainer_devices_fraction` and `sampler_devices_fraction`) when scaling to multi-host Pathways or to larger checkpoints such as Llama3.1-70B.
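
The device split described in this hunk (both fractions at `1.0` share one mesh; `0.5`/`0.5` gives the first half to the trainer and the second half to rollout) can be sketched as below. This is only an illustration under stated assumptions: `split_devices` is a hypothetical helper, not a MaxText function, and the behavior for fractions other than the two cases named in the text (sampler devices taken immediately after the trainer devices) is an assumption.

```python
def split_devices(devices, trainer_fraction, sampler_fraction):
    """Hypothetical helper: split a flat device list between trainer and sampler.

    Not part of MaxText; it only mirrors the two cases described in the doc text.
    """
    devices = list(devices)
    for name, frac in (("trainer", trainer_fraction), ("sampler", sampler_fraction)):
        if not 0.0 < frac <= 1.0:
            raise ValueError(f"{name} fraction must be in (0, 1], got {frac}")

    # Both fractions at 1.0: trainer and rollout share the entire mesh.
    if trainer_fraction == 1.0 and sampler_fraction == 1.0:
        return devices, list(devices)

    num_trainer = int(len(devices) * trainer_fraction)
    num_sampler = int(len(devices) * sampler_fraction)
    if num_trainer == 0 or num_sampler == 0 or num_trainer + num_sampler > len(devices):
        raise ValueError("fractions do not yield two non-empty, disjoint device groups")

    # Assumption: the sampler takes the devices right after the trainer's.
    return devices[:num_trainer], devices[num_trainer:num_trainer + num_sampler]


trainer, sampler = split_devices(list(range(8)), 0.5, 0.5)
print(trainer, sampler)  # [0, 1, 2, 3] [4, 5, 6, 7]
```

Integers stand in for real device objects here; in practice the list would come from something like `jax.devices()`.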
232 changes: 168 additions & 64 deletions src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
@@ -4,20 +4,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# GRPO Llama3.1-8B Demo: Direct Function Call\n",
"# GRPO Llama3.1-8B Demo\n",
"\n",
"This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n",
"This notebook demonstrates GRPO (Group Relative Policy Optimization) training using the unified `rl_train` function.\n",
"\n",
"## What is GRPO?\n",
"\n",
"GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
"GRPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
"1. Generating multiple responses for each prompt\n",
"2. Evaluating responses using reward models \n",
"3. Calculating relative advantages to update the policy\n",
"\n",
"\n",
"This notebook imports and calls the `rl_train` function \n",
"\n",
"## Hardware Requirements\n",
"\n",
"- Single host TPUVM (v6e-8/v5p-8) or multi-host with Pathways\n",
Collaborator

How do we run this Colab on multi-host Pathways?

Collaborator Author

Added the section
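
For readers of this thread, the added "Multi-host Pathways" section might translate into notebook code roughly as follows. This is a sketch, not the PR's code: `pathways_overrides` is a hypothetical helper, the `0.5`/`0.5` fractions and `chips_per_vm=4` are placeholder values, and the `key=value` form is assumed from the way `config_argv` is built elsewhere in the notebook. The environment variable names are the ones the section itself gives as examples.

```python
import os


def pathways_overrides(trainer_fraction, sampler_fraction, chips_per_vm):
    """Hypothetical helper: build key=value overrides to append to config_argv."""
    return [
        "use_pathways=True",
        f"trainer_devices_fraction={trainer_fraction}",
        f"sampler_devices_fraction={sampler_fraction}",
        f"chips_per_vm={chips_per_vm}",
    ]


# Export the Pathways runtime variables before any training code runs.
# setdefault keeps values that were already exported on the controller host.
os.environ.setdefault("JAX_PLATFORMS", "proxy")
os.environ.setdefault("ENABLE_PATHWAYS_PERSISTENCE", "1")

# Placeholder values; adjust to the actual slice topology.
extra_argv = pathways_overrides(0.5, 0.5, chips_per_vm=4)
print(extra_argv)
```

The resulting list would be appended to `config_argv` before it is passed to `pyconfig.initialize`.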

@@ -40,8 +37,8 @@
"outputs": [],
"source": [
"# Clone MaxText repository\n",
"!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
"%cd maxtext"
"!git clone https://github.com/AI-Hypercomputer/maxtext\n",
"%cd maxtext/src"
]
},
{
@@ -50,19 +47,57 @@
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"!chmod +x setup.sh\n",
"!./setup.sh\n",
"!bash tools/setup/setup.sh\n",
"%pip uninstall -y jax jaxlib libtpu\n",
"\n",
"%pip install aiohttp==3.12.15\n",
"\n",
"# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.\n",
"%pip install keyring keyrings.google-artifactregistry-auth\n",
"\n",
"# Install vLLM for Jax and TPUs from the artifact registry\n",
"!VLLM_TARGET_DEVICE=\"tpu\" pip install --no-cache-dir --pre \\\n",
" --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
" --extra-index-url https://pypi.org/simple/ \\\n",
" --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
" --extra-index-url https://download.pytorch.org/whl/nightly/cpu \\\n",
" --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
" --find-links https://storage.googleapis.com/libtpu-wheels/index.html \\\n",
" --find-links https://storage.googleapis.com/libtpu-releases/index.html \\\n",
" --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n",
" --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \\\n",
" vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu\n",
"\n",
"# Install GRPO-specific dependencies\n",
"!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n",
"# Install tpu-commons from the artifact registry\n",
"%pip install --no-cache-dir --pre \\\n",
" --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
" --extra-index-url https://pypi.org/simple/ \\\n",
" --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
" --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
" tpu-commons==0.1.2\n",
"\n",
"%pip install numba==0.61.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Install additional requirements\n",
"%pip install --force-reinstall numpy==2.1.2\n",
"%pip install nest_asyncio\n",
"\n",
"import nest_asyncio\n",
"nest_asyncio.apply() # Fix for Colab event loop"
"nest_asyncio.apply() # Fix for Colab event loop\n",
"\n",
"%cd maxtext/src/\n",
"\n",
"#Fix nnx problems\n",
"!pip uninstall flax \n",
"!pip uninstall qwix\n",
"!pip install flax \n",
"!pip install qwix"
]
},
{
@@ -71,9 +106,27 @@
"source": [
"## Configuration\n",
"\n",
"Set up the training parameters:"
"Set up the training parameters. Defaults are hardcoded for Llama3.1-8B:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-host Pathways\n",
"\n",
"To run this demo on a multi-host Pathways setup:\n",
"- Set `use_pathways=True` in `rl.yml` (enabled by default).\n",
"- Override `trainer_devices_fraction` and `sampler_devices_fraction` in `config_argv` to split the mesh across hosts.\n",
"- Launch the Colab kernel on the controller host and export Pathways runtime variables (for example `JAX_PLATFORMS=proxy` and `ENABLE_PATHWAYS_PERSISTENCE=1`) before running training.\n",
"- Update `chips_per_vm` to match your slice topology; Pathways will shard trainer and rollout workers automatically.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
@@ -82,20 +135,38 @@
"source": [
"# Configuration for GRPO training\n",
"import os\n",
"import MaxText\n",
"\n",
"# Set up paths (adjust if needed)\n",
"MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
"RUN_NAME=\"grpo_test\"\n",
"# Hardcoded defaults for Llama3.1-8B\n",
"MODEL_NAME = \"llama3.1-8b\"\n",
"HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
"CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
"LOSS_ALGO=\"gspo-token\"\n",
"\n",
"# Set up paths\n",
"MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n",
"print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n",
"# Required: Set these before running\n",
"MODEL_CHECKPOINT_PATH = \"\" # Update this!\n",
"OUTPUT_DIRECTORY = \"/tmp/gpo_output\" # Update this!\n",
"HF_TOKEN = \"\" # Set HF_TOKEN environment variable\n",
"\n",
"# Training configuration\n",
"MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n",
"OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n",
"# Optional: Override training parameters\n",
"STEPS = 10 # Reduced for demo purposes\n",
"HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n",
"PER_DEVICE_BATCH_SIZE = 1\n",
"LEARNING_RATE = 3e-6\n",
"NUM_GENERATIONS = 2\n",
"GRPO_BETA = 0.08\n",
"GRPO_EPSILON = 0.2\n",
"CHIPS_PER_VM = 1\n",
"\n",
"print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
"print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n",
"print(f\"Training steps: {STEPS}\")"
"print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n",
"print(f\"🤖 Model: {MODEL_NAME}\")\n",
"print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
"print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n",
"print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n",
"print(f\"📊 Steps: {STEPS}\")\n",
"print(f\"Loss Algorithm : {LOSS_ALGO}\")"
]
},
{
@@ -104,24 +175,33 @@
"metadata": {},
"outputs": [],
"source": [
"# Import GRPO training function directly\n",
"import sys\n",
"# Import required modules\n",
"import os\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"# Add MaxText to Python path\n",
"maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n",
"maxtext_path = Path(MAXTEXT_REPO_ROOT) \n",
"sys.path.insert(0, str(maxtext_path))\n",
"\n",
"# Import required modules\n",
"from MaxText import pyconfig\n",
"from MaxText.train_rl import rl_train\n",
"from MaxText import pyconfig, max_utils\n",
"from MaxText.rl.train_rl import rl_train\n",
"import jax\n",
"\n",
"print(\"✅ Successfully imported GRPO training function\")\n",
"print(f\"📁 MaxText path: {maxtext_path}\")\n",
"print(\"\\n\" + \"=\"*80)\n",
"print(\"Starting GRPO Training...\")\n",
"print(\"=\"*80)"
"# Initialize JAX and Pathways\n",
"import pathwaysutils\n",
"pathwaysutils.initialize()\n",
"jax.config.update(\"jax_default_prng_impl\", \"unsafe_rbg\")\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
"os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
"\n",
"if \"xla_tpu_spmd_rng_bit_generator_unsafe\" not in os.environ.get(\"LIBTPU_INIT_ARGS\", \"\"):\n",
" os.environ[\"LIBTPU_INIT_ARGS\"] = (\n",
" os.environ.get(\"LIBTPU_INIT_ARGS\", \"\") + \" --xla_tpu_spmd_rng_bit_generator_unsafe=true\"\n",
" )\n",
"\n",
"print(\"✅ Successfully imported modules\")\n",
"print(f\"📁 MaxText path: {maxtext_path}\")"
]
},
{
@@ -131,28 +211,39 @@
"outputs": [],
"source": [
"# Build configuration for GRPO training\n",
"config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n",
"\n",
"# Verify chat template exists\n",
"if not os.path.exists(os.path.join(MAXTEXT_REPO_ROOT, CHAT_TEMPLATE_PATH)):\n",
" raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
"\n",
"# Build argv list for pyconfig.initialize()\n",
"config_argv = [\n",
" \"\", # Placeholder for argv[0]\n",
" \"src/MaxText/configs/grpo.yml\", # Base config\n",
" f\"model_name=llama3.1-8b\",\n",
" f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n",
" \"\", # argv[0] placeholder\n",
" config_file,\n",
" f\"model_name={MODEL_NAME}\",\n",
" f\"tokenizer_path={HF_REPO_ID}\",\n",
" f\"run_name={RUN_NAME}\",\n",
" f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
" f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
" f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
" f\"hf_access_token={HF_TOKEN}\",\n",
" f\"steps={STEPS}\",\n",
" \"per_device_batch_size=1\",\n",
" \"learning_rate=3e-6\",\n",
" \"num_generations=2\",\n",
" \"grpo_beta=0.08\",\n",
" \"grpo_epsilon=0.2\",\n",
" \"chips_per_vm=4\"\n",
" f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n",
" f\"learning_rate={LEARNING_RATE}\",\n",
" f\"num_generations={NUM_GENERATIONS}\",\n",
" f\"grpo_beta={GRPO_BETA}\",\n",
" f\"grpo_epsilon={GRPO_EPSILON}\",\n",
" f\"chips_per_vm={CHIPS_PER_VM}\",\n",
" f\"loss_algo={LOSS_ALGO}\"\n",
"]\n",
"\n",
"# Create configuration object\n",
"config = pyconfig.Config()\n",
"config.parse_flags(config_argv)\n",
"# Initialize configuration\n",
"print(f\"🔧 Initializing configuration from: {config_file}\")\n",
"config = pyconfig.initialize(config_argv)\n",
"max_utils.print_system_information()\n",
"\n",
"print(\"✅ Configuration created successfully\")\n",
"print(\"\\n✅ Configuration initialized successfully\")\n",
"print(f\"📊 Training steps: {config.steps}\")\n",
"print(f\"📁 Output directory: {config.base_output_directory}\")\n",
"print(f\"🤖 Model: {config.model_name}\")"
@@ -164,33 +255,46 @@
"metadata": {},
"outputs": [],
"source": [
"# Execute GRPO training directly\n",
"# Execute GRPO/GSPO training\n",
"print(\"\\n\" + \"=\"*80)\n",
"print(\"🚀 Starting Training...\")\n",
"print(\"=\"*80)\n",
"print(1)\n",
"try:\n",
" # Call the rl_train function\n",
" grpo_trainer, rl_cluster = rl_train(config)\n",
" # Call the rl_train function (it handles everything internally)\n",
" rl_train(config)\n",
" \n",
" print(\"\\n\" + \"=\"*80)\n",
" print(\"✅ GRPO Training Completed Successfully!\")\n",
" print(\"✅ Training Completed Successfully!\")\n",
" print(\"=\"*80)\n",
" print(f\"📁 Checkpoints and logs saved to: {config.base_output_directory}\")\n",
" print(f\"🎯 Final model ready for inference!\")\n",
" print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n",
" print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n",
" print(f\"🎯 Model ready for inference!\")\n",
" \n",
"except Exception as e:\n",
" print(\"\\n\" + \"=\"*80)\n",
" print(\"❌ GRPO Training Failed!\")\n",
" print(\"❌Training Failed!\")\n",
" print(\"=\"*80)\n",
" print(f\"Error: {str(e)}\")\n",
" print(\"\\nPlease check the error message and try again.\")"
" import traceback\n",
" traceback.print_exc()\n",
" print(\"\\n💡 Common issues:\")\n",
" print(\" - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint\")\n",
" print(\" - Ensure HF_TOKEN environment variable is set\")\n",
" print(\" - Verify OUTPUT_DIRECTORY is writable\")\n",
" print(\" - Check hardware requirements (TPU/GPU availability)\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 📚 **Learn More**\n",
"- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n",
"- Check `src/MaxText/configs/grpo.yml` for configuration options\n",
"- Read `src/MaxText/examples/README.md` for more examples"
"## 📚 Learn More\n",
"\n",
"- **CLI Usage**: Run `python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml --model_name=llama3.1-8b ...`\n",
"- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n",
"- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation\n",
"- **Examples**: See other examples in `src/MaxText/examples/`"
]
}
],
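
The three GRPO steps listed in the notebook's "What is GRPO?" cell end with "calculating relative advantages". A minimal sketch of that step, using the commonly cited group-normalized form (each reward minus the group mean, divided by the group standard deviation), is shown below. The function name and the `eps` stabilizer are illustrative assumptions; this is not taken from `rl_train`, whose internals this diff does not show.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Illustrative only: normalize rewards within one prompt's group of responses."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    # eps avoids division by zero when every response gets the same reward.
    return [(r - mean) / (std + eps) for r in rewards]


# Four sampled responses to one prompt, scored 1.0 (correct) or 0.0 (incorrect).
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print(advantages)
```

Responses scoring above the group mean get positive advantages and the rest get negative ones, so the policy update favors the better responses within each group without a separate value model.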