diff --git a/docs/tutorials/grpo_with_pathways.md b/docs/tutorials/grpo_with_pathways.md
index c2ce421c7..895dbdb51 100644
--- a/docs/tutorials/grpo_with_pathways.md
+++ b/docs/tutorials/grpo_with_pathways.md
@@ -69,9 +69,4 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
   --hf_access_token=$HF_TOKEN"
 ```
 
-The overview of the demo script ~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:
-
-1. We load a policy model and a reference model. Both are copies of `Llama3.1-70b-Instruct`.
-2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
-3. Train the policy model using GRPO with potentially different meshes for trainer and rollout depending on the parameters `TRAINER_DEVICES_FRACTION` and `SAMPLER_DEVICES_FRACTION`. If we set both of these to `1.0`, the entire (same) mesh will be used for both trainer and rollout. If we set say `TRAINER_DEVICES_FRACTION=0.5` and `SAMPLER_DEVICES_FRACTION=0.5`, the first half of the devices will be used for trainer and the second half will be used for rollout
-4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
+For an interactive walkthrough, open `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb`. The notebook now delegates to `rl_train`, so you can reuse the same configuration flags shown above (including `trainer_devices_fraction` and `sampler_devices_fraction`) when scaling to multi-host Pathways or to larger checkpoints such as Llama3.1-70B.
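+
+For example, to run the trainer on the first half of the devices and rollout on the second half, override both fractions; setting both to `1.0` instead runs trainer and rollout on the entire shared mesh. A sketch (adjust model, paths, and tokens to your setup):
+
+```bash
+python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+  --model_name=llama3.1-70b \
+  --trainer_devices_fraction=0.5 \
+  --sampler_devices_fraction=0.5 \
+  --hf_access_token=$HF_TOKEN
+```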
diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
index e24cbba99..e17221eaf 100644
--- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
+++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
@@ -4,20 +4,17 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "# GRPO Llama3.1-8B Demo: Direct Function Call\n",
+   "# GRPO Llama3.1-8B Demo\n",
    "\n",
-   "This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n",
+   "This notebook demonstrates GRPO (Group Relative Policy Optimization) training using the unified `rl_train` function.\n",
    "\n",
    "## What is GRPO?\n",
    "\n",
-   "GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
+   "GRPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
    "1. Generating multiple responses for each prompt\n",
    "2. Evaluating responses using reward models \n",
    "3. Calculating relative advantages to update the policy\n",
    "\n",
-   "\n",
-   "This notebook imports and calls the `rl_train` function \n",
-   "\n",
    "## Hardware Requirements\n",
    "\n",
    "- Single host TPUVM (v6e-8/v5p-8) or multi-host with Pathways\n",
@@ -40,8 +37,8 @@
   "outputs": [],
   "source": [
    "# Clone MaxText repository\n",
-   "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
-   "%cd maxtext"
+   "!git clone https://github.com/AI-Hypercomputer/maxtext\n",
+   "%cd maxtext/src"
   ]
  },
 {
@@ -50,19 +47,57 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# Install dependencies\n",
-   "!chmod +x setup.sh\n",
-   "!./setup.sh\n",
+   "# Run the MaxText setup script\n",
+   "!bash tools/setup/setup.sh\n",
+   "%pip uninstall -y jax jaxlib libtpu\n",
+   "\n",
+   "%pip install aiohttp==3.12.15\n",
+   "\n",
+   "# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.\n",
+   "%pip install keyring keyrings.google-artifactregistry-auth\n",
+   "\n",
+   "# Install vLLM for JAX and TPUs from the artifact registry\n",
+   "!VLLM_TARGET_DEVICE=\"tpu\" pip install --no-cache-dir --pre \\\n",
+   " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
+   " --extra-index-url https://pypi.org/simple/ \\\n",
+   " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
+   " --extra-index-url https://download.pytorch.org/whl/nightly/cpu \\\n",
+   " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
+   " --find-links https://storage.googleapis.com/libtpu-wheels/index.html \\\n",
+   " --find-links https://storage.googleapis.com/libtpu-releases/index.html \\\n",
+   " --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n",
+   " --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \\\n",
+   " vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu\n",
+   "\n",
+   "# Install tpu-commons from the artifact registry\n",
+   "%pip install --no-cache-dir --pre \\\n",
+   " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
+   " --extra-index-url https://pypi.org/simple/ \\\n",
+   " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
+   " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
+   " tpu-commons==0.1.2\n",
+   "\n",
+   "%pip install numba==0.61.2"
   ]
  },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "\n",
-   "# Install additional requirements\n",
-   "%pip install --force-reinstall numpy==2.1.2\n",
   "%pip install nest_asyncio\n",
   "\n",
   "import nest_asyncio\n",
-   "nest_asyncio.apply() # Fix for Colab event loop"
+   "nest_asyncio.apply()  # Fix for Colab event loop\n",
+   "\n",
+   "# If the kernel was restarted after the installs above, return to the repo\n",
+   "%cd maxtext/src/\n",
+   "\n",
+   "# Fix nnx import problems: reinstall flax and qwix cleanly\n",
+   "!pip uninstall -y flax qwix\n",
+   "!pip install flax qwix"
  ]
 },
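+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# Optional sanity check (an addition, not part of the original demo):\n",
+  "# confirm that JAX sees the TPU chips before configuring training.\n",
+  "# On multi-host Pathways, skip this cell; the backend is initialized\n",
+  "# later via pathwaysutils.\n",
+  "import jax\n",
+  "\n",
+  "print(f\"JAX version: {jax.__version__}\")\n",
+  "print(f\"Devices: {jax.devices()}\")\n",
+  "print(f\"Device count: {jax.device_count()}\")"
+ ]
+},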
 {
@@ -71,9 +106,27 @@
  "source": [
   "## Configuration\n",
   "\n",
-   "Set up the training parameters:"
+   "Set up the training parameters. Defaults are hardcoded for Llama3.1-8B:"
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Multi-host Pathways\n",
+  "\n",
+  "To run this demo on a multi-host Pathways setup (see the sketch after this list):\n",
+  "- Set `use_pathways=True` in `rl.yml` (enabled by default).\n",
+  "- Override `trainer_devices_fraction` and `sampler_devices_fraction` in `config_argv` to split the mesh across hosts.\n",
+  "- Launch the Colab kernel on the controller host and export Pathways runtime variables (for example `JAX_PLATFORMS=proxy` and `ENABLE_PATHWAYS_PERSISTENCE=1`) before running training.\n",
+  "- Update `chips_per_vm` to match your slice topology; Pathways will shard trainer and rollout workers automatically.\n"
+ ]
+},
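+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# Multi-host sketch (an addition, not part of the original demo): export the\n",
+  "# Pathways runtime variables listed above before JAX is initialized.\n",
+  "# Leave RUNNING_ON_PATHWAYS as False on a single-host TPUVM.\n",
+  "import os\n",
+  "\n",
+  "RUNNING_ON_PATHWAYS = False  # flip to True on the Pathways controller host\n",
+  "if RUNNING_ON_PATHWAYS:\n",
+  "    os.environ[\"JAX_PLATFORMS\"] = \"proxy\"\n",
+  "    os.environ[\"ENABLE_PATHWAYS_PERSISTENCE\"] = \"1\"\n",
+  "    print(\"Pathways runtime variables exported\")"
+ ]
+},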
 {
  "cell_type": "code",
  "execution_count": null,
@@ -82,20 +135,38 @@
  "metadata": {},
  "outputs": [],
  "source": [
   "# Configuration for GRPO training\n",
   "import os\n",
+   "import MaxText\n",
+   "\n",
+   "# Set up paths: MAXTEXT_REPO_ROOT points at the installed MaxText package (src/MaxText)\n",
+   "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
+   "RUN_NAME = \"grpo_test\"\n",
+   "\n",
+   "# Hardcoded defaults for Llama3.1-8B\n",
+   "MODEL_NAME = \"llama3.1-8b\"\n",
+   "HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
+   "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
+   "LOSS_ALGO = \"gspo-token\"\n",
   "\n",
-   "# Set up paths\n",
-   "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n",
-   "print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n",
+   "# Required: set these before running\n",
+   "MODEL_CHECKPOINT_PATH = \"\"  # Update this!\n",
+   "OUTPUT_DIRECTORY = \"/tmp/grpo_output\"  # Update this!\n",
+   "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")  # Read from the HF_TOKEN environment variable\n",
   "\n",
-   "# Training configuration\n",
-   "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n",
-   "OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n",
+   "# Optional: override training parameters\n",
   "STEPS = 10 # Reduced for demo purposes\n",
-   "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n",
+   "PER_DEVICE_BATCH_SIZE = 1\n",
+   "LEARNING_RATE = 3e-6\n",
+   "NUM_GENERATIONS = 2\n",
+   "GRPO_BETA = 0.08\n",
+   "GRPO_EPSILON = 0.2\n",
+   "CHIPS_PER_VM = 1\n",
   "\n",
-   "print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
-   "print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n",
-   "print(f\"Training steps: {STEPS}\")"
+   "print(f\"πŸ“ MaxText Home: {MAXTEXT_REPO_ROOT}\")\n",
+   "print(f\"πŸ€– Model: {MODEL_NAME}\")\n",
+   "print(f\"πŸ“¦ Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
+   "print(f\"πŸ’Ύ Output: {OUTPUT_DIRECTORY}\")\n",
+   "print(f\"πŸ”‘ HF Token: {'βœ… Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n",
+   "print(f\"πŸ“Š Steps: {STEPS}\")\n",
+   "print(f\"Loss algorithm: {LOSS_ALGO}\")"
  ]
 },
 {
@@ -104,24 +175,33 @@
  "metadata": {},
  "outputs": [],
  "source": [
-   "# Import GRPO training function directly\n",
-   "import sys\n",
+   "# Import required modules\n",
   "import os\n",
+   "import sys\n",
   "from pathlib import Path\n",
   "\n",
   "# Add MaxText to Python path\n",
-   "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n",
+   "maxtext_path = Path(MAXTEXT_REPO_ROOT).parent  # parent directory of the MaxText package\n",
   "sys.path.insert(0, str(maxtext_path))\n",
   "\n",
-   "# Import required modules\n",
-   "from MaxText import pyconfig\n",
-   "from MaxText.train_rl import rl_train\n",
+   "from MaxText import pyconfig, max_utils\n",
+   "from MaxText.rl.train_rl import rl_train\n",
+   "import jax\n",
   "\n",
-   "print(\"βœ… Successfully imported GRPO training function\")\n",
-   "print(f\"πŸ“ MaxText path: {maxtext_path}\")\n",
-   "print(\"\\n\" + \"=\"*80)\n",
-   "print(\"Starting GRPO Training...\")\n",
-   "print(\"=\"*80)"
+   "# Initialize JAX and Pathways\n",
+   "import pathwaysutils\n",
+   "pathwaysutils.initialize()\n",
+   "jax.config.update(\"jax_default_prng_impl\", \"unsafe_rbg\")\n",
+   "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
+   "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\"  # Faster startup for vLLM\n",
+   "\n",
+   "if \"xla_tpu_spmd_rng_bit_generator_unsafe\" not in os.environ.get(\"LIBTPU_INIT_ARGS\", \"\"):\n",
+   "    os.environ[\"LIBTPU_INIT_ARGS\"] = (\n",
+   "        os.environ.get(\"LIBTPU_INIT_ARGS\", \"\") + \" --xla_tpu_spmd_rng_bit_generator_unsafe=true\"\n",
+   "    )\n",
+   "\n",
+   "print(\"βœ… Successfully imported modules\")\n",
+   "print(f\"πŸ“ MaxText path: {maxtext_path}\")"
  ]
 },
 {
@@ -131,28 +211,39 @@
  "outputs": [],
  "source": [
   "# Build configuration for GRPO training\n",
+   "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n",
+   "\n",
+   "# Verify the chat template exists (CHAT_TEMPLATE_PATH is already a full path)\n",
+   "if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
+   "    raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
+   "\n",
+   "# Build argv list for pyconfig.initialize()\n",
   "config_argv = [\n",
-   "    \"\",  # Placeholder for argv[0]\n",
-   "    \"src/MaxText/configs/grpo.yml\",  # Base config\n",
-   "    f\"model_name=llama3.1-8b\",\n",
-   "    f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n",
+   "    \"\",  # argv[0] placeholder\n",
+   "    config_file,\n",
+   "    f\"model_name={MODEL_NAME}\",\n",
+   "    f\"tokenizer_path={HF_REPO_ID}\",\n",
+   "    f\"run_name={RUN_NAME}\",\n",
+   "    f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
   "    f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
   "    f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
   "    f\"hf_access_token={HF_TOKEN}\",\n",
   "    f\"steps={STEPS}\",\n",
-   "    \"per_device_batch_size=1\",\n",
-   "    \"learning_rate=3e-6\",\n",
-   "    \"num_generations=2\",\n",
-   "    \"grpo_beta=0.08\",\n",
-   "    \"grpo_epsilon=0.2\",\n",
-   "    \"chips_per_vm=4\"\n",
+   "    f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n",
+   "    f\"learning_rate={LEARNING_RATE}\",\n",
+   "    f\"num_generations={NUM_GENERATIONS}\",\n",
+   "    f\"grpo_beta={GRPO_BETA}\",\n",
+   "    f\"grpo_epsilon={GRPO_EPSILON}\",\n",
+   "    f\"chips_per_vm={CHIPS_PER_VM}\",\n",
+   "    f\"loss_algo={LOSS_ALGO}\"\n",
   "]\n",
   "\n",
-   "# Create configuration object\n",
-   "config = pyconfig.Config()\n",
-   "config.parse_flags(config_argv)\n",
+   "# Initialize configuration\n",
+   "print(f\"πŸ”§ Initializing configuration from: {config_file}\")\n",
+   "config = pyconfig.initialize(config_argv)\n",
+   "max_utils.print_system_information()\n",
   "\n",
-   "print(\"βœ… Configuration created successfully\")\n",
+   "print(\"\\nβœ… Configuration initialized successfully\")\n",
   "print(f\"πŸ“Š Training steps: {config.steps}\")\n",
   "print(f\"πŸ“ Output directory: {config.base_output_directory}\")\n",
   "print(f\"πŸ€– Model: {config.model_name}\")"
  ]
 },
 {
@@ -164,33 +255,46 @@
  "metadata": {},
  "outputs": [],
  "source": [
-   "# Execute GRPO training directly\n",
+   "# Execute GRPO/GSPO training\n",
+   "print(\"\\n\" + \"=\"*80)\n",
+   "print(\"πŸš€ Starting Training...\")\n",
+   "print(\"=\"*80)\n",
   "try:\n",
-   "    # Call the rl_train function\n",
-   "    grpo_trainer, rl_cluster = rl_train(config)\n",
+   "    # Call the rl_train function (it handles everything internally)\n",
+   "    rl_train(config)\n",
   "    \n",
   "    print(\"\\n\" + \"=\"*80)\n",
-   "    print(\"βœ… GRPO Training Completed Successfully!\")\n",
+   "    print(\"βœ… Training Completed Successfully!\")\n",
   "    print(\"=\"*80)\n",
-   "    print(f\"πŸ“ Checkpoints and logs saved to: {config.base_output_directory}\")\n",
-   "    print(f\"🎯 Final model ready for inference!\")\n",
+   "    print(f\"πŸ“ Checkpoints saved to: {config.checkpoint_dir}\")\n",
+   "    print(f\"πŸ“Š TensorBoard logs: {config.tensorboard_dir}\")\n",
+   "    print(f\"🎯 Model ready for inference!\")\n",
   "    \n",
   "except Exception as e:\n",
   "    print(\"\\n\" + \"=\"*80)\n",
-   "    print(\"❌ GRPO Training Failed!\")\n",
+   "    print(\"❌ Training Failed!\")\n",
   "    print(\"=\"*80)\n",
   "    print(f\"Error: {str(e)}\")\n",
-   "    print(\"\\nPlease check the error message and try again.\")"
+   "    import traceback\n",
+   "    traceback.print_exc()\n",
+   "    print(\"\\nπŸ’‘ Common issues:\")\n",
+   "    print(\" - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint\")\n",
+   "    print(\" - Ensure HF_TOKEN environment variable is set\")\n",
+   "    print(\" - Verify OUTPUT_DIRECTORY is writable\")\n",
+   "    print(\" - Check hardware requirements (TPU/GPU availability)\")"
  ]
 },
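+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Inspect the outputs (optional)\n",
+  "\n",
+  "A minimal sketch, not part of the original demo: list the checkpoint directory printed above to confirm the run produced artifacts. Assumes `gsutil` is available when the output directory is a `gs://` bucket.\n"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# List whatever the run wrote under the checkpoint directory.\n",
+  "ckpt_dir = config.checkpoint_dir\n",
+  "if str(ckpt_dir).startswith(\"gs://\"):\n",
+  "    !gsutil ls -r {ckpt_dir}\n",
+  "else:\n",
+  "    !ls -R {ckpt_dir}"
+ ]
+},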
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-   "### πŸ“š **Learn More**\n",
-   "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n",
-   "- Check `src/MaxText/configs/grpo.yml` for configuration options\n",
-   "- Read `src/MaxText/examples/README.md` for more examples"
+   "## πŸ“š Learn More\n",
+   "\n",
+   "- **CLI Usage**: Run `python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml --model_name=llama3.1-8b ...`\n",
+   "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n",
+   "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation\n",
+   "- **Examples**: See other examples in `src/MaxText/examples/`"
  ]
 }
],