From 492c674b6a81faa487b674163939627a48238f4f Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Thu, 6 Nov 2025 23:34:08 +0400 Subject: [PATCH 01/10] Use train_rl Signed-off-by: Vladimir Suvorov --- .../examples/grpo_llama3_1_8b_demo.ipynb | 156 ++++++++++++------ 1 file changed, 106 insertions(+), 50 deletions(-) diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index e24cbba99..567e7a8d4 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -4,20 +4,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# GRPO Llama3.1-8B Demo: Direct Function Call\n", + "# GRPO Llama3.1-8B Demo\n", "\n", - "This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n", + "This notebook demonstrates GRPO (Group Relative Policy Optimization) training using the unified `rl_train` function.\n", "\n", "## What is GRPO?\n", "\n", - "GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n", + "GRPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n", "1. Generating multiple responses for each prompt\n", "2. Evaluating responses using reward models \n", "3. Calculating relative advantages to update the policy\n", "\n", - "\n", - "This notebook imports and calls the `rl_train` function \n", - "\n", "## Hardware Requirements\n", "\n", "- Single host TPUVM (v6e-8/v5p-8) or multi-host with Pathways\n", @@ -71,7 +68,7 @@ "source": [ "## Configuration\n", "\n", - "Set up the training parameters:" + "Set up the training parameters. Defaults are hardcoded for Llama3.1-8B:" ] }, { @@ -83,19 +80,35 @@ "# Configuration for GRPO training\n", "import os\n", "\n", - "# Set up paths\n", + "# Set up paths (adjust if needed)\n", "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n", - "print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n", "\n", - "# Training configuration\n", - "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n", - "OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n", + "# Hardcoded defaults for Llama3.1-8B\n", + "MODEL_NAME = \"llama3.1-8b\"\n", + "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "HF_MODEL_NAME = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "CHAT_TEMPLATE_PATH = \"src/MaxText/examples/chat_templates/gsm8k_rl.json\"\n", + "\n", + "# Required: Set these before running\n", + "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\" # Update this!\n", + "OUTPUT_DIRECTORY = \"/tmp/grpo_output\" # Update this!\n", + "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\") # Set HF_TOKEN environment variable\n", + "\n", + "# Optional: Override training parameters\n", "STEPS = 10 # Reduced for demo purposes\n", - "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n", + "PER_DEVICE_BATCH_SIZE = 1\n", + "LEARNING_RATE = 3e-6\n", + "NUM_GENERATIONS = 2\n", + "GRPO_BETA = 0.08\n", + "GRPO_EPSILON = 0.2\n", + "CHIPS_PER_VM = 4\n", "\n", - "print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", - "print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n", - "print(f\"Training steps: {STEPS}\")" + "print(f\"πŸ“ MaxText Home: {MAXTEXT_REPO_ROOT}\")\n", + "print(f\"πŸ€– Model: {MODEL_NAME}\")\n", + "print(f\"πŸ“¦ Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", + "print(f\"πŸ’Ύ Output: {OUTPUT_DIRECTORY}\")\n", + "print(f\"πŸ”‘ HF Token: {'βœ… Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n", + "print(f\"πŸ“Š Steps: {STEPS}\")" ] }, { @@ -104,24 +117,33 @@ "metadata": {}, "outputs": [], "source": [ - "# Import GRPO training function directly\n", - "import sys\n", + "# Import required modules\n", "import os\n", + "import sys\n", "from pathlib import Path\n", "\n", "# Add MaxText to Python path\n", "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n", "sys.path.insert(0, str(maxtext_path))\n", "\n", - "# Import required modules\n", - "from MaxText import pyconfig\n", - "from MaxText.train_rl import rl_train\n", + "from MaxText import pyconfig, max_utils\n", + "from MaxText.rl.train_rl import rl_train\n", + "import jax\n", "\n", - "print(\"βœ… Successfully imported GRPO training function\")\n", - "print(f\"πŸ“ MaxText path: {maxtext_path}\")\n", - "print(\"\\n\" + \"=\"*80)\n", - "print(\"Starting GRPO Training...\")\n", - "print(\"=\"*80)" + "# Initialize JAX and Pathways\n", + "import pathwaysutils\n", + "pathwaysutils.initialize()\n", + "jax.config.update(\"jax_default_prng_impl\", \"unsafe_rbg\")\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n", + "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n", + "\n", + "if \"xla_tpu_spmd_rng_bit_generator_unsafe\" not in os.environ.get(\"LIBTPU_INIT_ARGS\", \"\"):\n", + " os.environ[\"LIBTPU_INIT_ARGS\"] = (\n", + " os.environ.get(\"LIBTPU_INIT_ARGS\", \"\") + \" --xla_tpu_spmd_rng_bit_generator_unsafe=true\"\n", + " )\n", + "\n", + "print(\"βœ… Successfully imported modules\")\n", + "print(f\"πŸ“ MaxText path: {maxtext_path}\")" ] }, { @@ -131,28 +153,49 @@ "outputs": [], "source": [ "# Build configuration for GRPO training\n", + "# Using rl.yml as the base config (not grpo.yml)\n", + "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"src/MaxText/configs/rl.yml\")\n", + "\n", + "# Verify chat template exists\n", + "if not os.path.exists(os.path.join(MAXTEXT_REPO_ROOT, CHAT_TEMPLATE_PATH)):\n", + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", + "\n", + "# Build argv list for pyconfig.initialize()\n", "config_argv = [\n", - " \"\", # Placeholder for argv[0]\n", - " \"src/MaxText/configs/grpo.yml\", # Base config\n", - " f\"model_name=llama3.1-8b\",\n", - " f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n", + " \"\", # argv[0] placeholder\n", + " config_file,\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " f\"hf_model_name={HF_MODEL_NAME}\",\n", + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", " f\"hf_access_token={HF_TOKEN}\",\n", " f\"steps={STEPS}\",\n", - " \"per_device_batch_size=1\",\n", - " \"learning_rate=3e-6\",\n", - " \"num_generations=2\",\n", - " \"grpo_beta=0.08\",\n", - " \"grpo_epsilon=0.2\",\n", - " \"chips_per_vm=4\"\n", + " f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n", + " f\"learning_rate={LEARNING_RATE}\",\n", + " f\"num_generations={NUM_GENERATIONS}\",\n", + " f\"grpo_beta={GRPO_BETA}\",\n", + " f\"grpo_epsilon={GRPO_EPSILON}\",\n", + " f\"chips_per_vm={CHIPS_PER_VM}\",\n", "]\n", "\n", - "# Create configuration object\n", - "config = pyconfig.Config()\n", - "config.parse_flags(config_argv)\n", + "print(\"πŸ“‹ Configuration parameters:\")\n", + "for arg in config_argv[2:]: # Skip argv[0] and config_file\n", + " key, value = arg.split(\"=\", 1)\n", + " # Mask sensitive values\n", + " if \"token\" in key.lower() or \"password\" in key.lower():\n", + " display_value = \"***\" if value else \"not set\"\n", + " else:\n", + " display_value = value[:60] + \"...\" if len(value) > 60 else value\n", + " print(f\" {key}={display_value}\")\n", "\n", - "print(\"βœ… Configuration created successfully\")\n", + "# Initialize configuration\n", + "print(f\"\\nπŸ”§ Initializing configuration from: {config_file}\")\n", + "config = pyconfig.initialize(config_argv)\n", + "max_utils.print_system_information()\n", + "\n", + "print(\"\\nβœ… Configuration initialized successfully\")\n", "print(f\"πŸ“Š Training steps: {config.steps}\")\n", "print(f\"πŸ“ Output directory: {config.base_output_directory}\")\n", "print(f\"πŸ€– Model: {config.model_name}\")" @@ -164,33 +207,46 @@ "metadata": {}, "outputs": [], "source": [ - "# Execute GRPO training directly\n", + "# Execute GRPO training\n", + "print(\"\\n\" + \"=\"*80)\n", + "print(\"πŸš€ Starting GRPO Training...\")\n", + "print(\"=\"*80)\n", + "\n", "try:\n", - " # Call the rl_train function\n", - " grpo_trainer, rl_cluster = rl_train(config)\n", + " # Call the rl_train function (it handles everything internally)\n", + " rl_train(config)\n", " \n", " print(\"\\n\" + \"=\"*80)\n", " print(\"βœ… GRPO Training Completed Successfully!\")\n", " print(\"=\"*80)\n", - " print(f\"πŸ“ Checkpoints and logs saved to: {config.base_output_directory}\")\n", - " print(f\"🎯 Final model ready for inference!\")\n", + " print(f\"πŸ“ Checkpoints saved to: {config.checkpoint_dir}\")\n", + " print(f\"πŸ“Š TensorBoard logs: {config.tensorboard_dir}\")\n", + " print(f\"🎯 Model ready for inference!\")\n", " \n", "except Exception as e:\n", " print(\"\\n\" + \"=\"*80)\n", " print(\"❌ GRPO Training Failed!\")\n", " print(\"=\"*80)\n", " print(f\"Error: {str(e)}\")\n", - " print(\"\\nPlease check the error message and try again.\")" + " import traceback\n", + " traceback.print_exc()\n", + " print(\"\\nπŸ’‘ Common issues:\")\n", + " print(\" - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint\")\n", + " print(\" - Ensure HF_TOKEN environment variable is set\")\n", + " print(\" - Verify OUTPUT_DIRECTORY is writable\")\n", + " print(\" - Check hardware requirements (TPU/GPU availability)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### πŸ“š **Learn More**\n", - "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n", - "- Check `src/MaxText/configs/grpo.yml` for configuration options\n", - "- Read `src/MaxText/examples/README.md` for more examples" + "## πŸ“š Learn More\n", + "\n", + "- **CLI Usage**: Run `python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml --model_name=llama3.1-8b ...`\n", + "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n", + "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation\n", + "- **Examples**: See other examples in `src/MaxText/examples/`" ] } ], From 5d5dcdd30c360bd5990b11a26e64913191a2dd67 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Fri, 7 Nov 2025 21:06:21 +0400 Subject: [PATCH 02/10] Fixed lint and quantization Signed-off-by: Vladimir Suvorov --- src/MaxText/layerwise_quantization.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/MaxText/layerwise_quantization.py b/src/MaxText/layerwise_quantization.py index 4b5fe045d..0d739d7be 100644 --- a/src/MaxText/layerwise_quantization.py +++ b/src/MaxText/layerwise_quantization.py @@ -38,6 +38,7 @@ import jax.numpy as jnp from absl import app +from flax import nnx from flax.linen import partitioning as nn_partitioning @@ -93,9 +94,31 @@ def load_and_quantize(self, rng: None | PRNGKeyType = None) -> None: self.quant.quant_mode = quantizations.get_quant_mode("convert") + if rng is None: + rng = jax.random.PRNGKey(0) + + dense_init_rng, moe_init_rng = jax.random.split(rng) + dense_params_rng, dense_dropout_rng = jax.random.split(dense_init_rng) + moe_params_rng, moe_dropout_rng = jax.random.split(moe_init_rng) + + dense_layer_rngs = nnx.Rngs(params=dense_params_rng, dropout=dense_dropout_rng) + moe_layer_rngs = nnx.Rngs(params=moe_params_rng, dropout=moe_dropout_rng) + layers = [ - deepseek.DeepSeekDenseLayer(config, mesh=self._mesh, quant=self.quant), - deepseek.DeepSeekMoELayer(config, mesh=self._mesh, quant=self.quant), + deepseek.DeepSeekDenseLayer( + config, + model_mode=common_types.MODEL_MODE_TRAIN, + mesh=self._mesh, + rngs=dense_layer_rngs, + quant=self.quant, + ), + deepseek.DeepSeekMoELayer( + config, + model_mode=common_types.MODEL_MODE_TRAIN, + mesh=self._mesh, + rngs=moe_layer_rngs, + quant=self.quant, + ), ] layer_prefixes = ["dense_layers", "moe_layers"] num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers From e512369c7be2c0f5c75acf5c0a44ac9b16f1fc86 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Fri, 7 Nov 2025 22:48:26 +0400 Subject: [PATCH 03/10] Fix linter Signed-off-by: Vladimir Suvorov --- src/MaxText/layerwise_quantization.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/src/MaxText/layerwise_quantization.py b/src/MaxText/layerwise_quantization.py index 0d739d7be..e9838b402 100644 --- a/src/MaxText/layerwise_quantization.py +++ b/src/MaxText/layerwise_quantization.py @@ -38,7 +38,6 @@ import jax.numpy as jnp from absl import app -from flax import nnx from flax.linen import partitioning as nn_partitioning @@ -94,29 +93,15 @@ def load_and_quantize(self, rng: None | PRNGKeyType = None) -> None: self.quant.quant_mode = quantizations.get_quant_mode("convert") - if rng is None: - rng = jax.random.PRNGKey(0) - - dense_init_rng, moe_init_rng = jax.random.split(rng) - dense_params_rng, dense_dropout_rng = jax.random.split(dense_init_rng) - moe_params_rng, moe_dropout_rng = jax.random.split(moe_init_rng) - - dense_layer_rngs = nnx.Rngs(params=dense_params_rng, dropout=dense_dropout_rng) - moe_layer_rngs = nnx.Rngs(params=moe_params_rng, dropout=moe_dropout_rng) - layers = [ - deepseek.DeepSeekDenseLayer( + deepseek.DeepSeekDenseLayer( # pylint: disable=no-value-for-parameter config, - model_mode=common_types.MODEL_MODE_TRAIN, mesh=self._mesh, - rngs=dense_layer_rngs, quant=self.quant, ), - deepseek.DeepSeekMoELayer( + deepseek.DeepSeekMoELayer( # pylint: disable=no-value-for-parameter config, - model_mode=common_types.MODEL_MODE_TRAIN, mesh=self._mesh, - rngs=moe_layer_rngs, quant=self.quant, ), ] From 6ab219e52f3638b5b4e6e7219fdec7931b506ec9 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 8 Nov 2025 00:11:36 +0400 Subject: [PATCH 04/10] Fix linter Signed-off-by: Vladimir Suvorov --- tests/pipeline_parallelism_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipeline_parallelism_test.py b/tests/pipeline_parallelism_test.py index 43efb62ca..bc3cba240 100644 --- a/tests/pipeline_parallelism_test.py +++ b/tests/pipeline_parallelism_test.py @@ -231,6 +231,7 @@ def test_circular_deepseek_megablox_same_output_and_grad(self): sparse_matmul=False, capacity_factor=1, decoder_block="deepseek", + attention_type="mla", ) self.assert_pipeline_same_output_and_grad(config, single_pipeline_stage_class=deepseek.DeepSeekMoELayer) From 6175903fec7fa672bc9d713f907cce4a16c5aa6c Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 8 Nov 2025 00:23:58 +0400 Subject: [PATCH 05/10] Fix comments Signed-off-by: Vladimir Suvorov --- .../examples/grpo_llama3_1_8b_demo.ipynb | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index 567e7a8d4..d9de3ab20 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -71,6 +71,24 @@ "Set up the training parameters. Defaults are hardcoded for Llama3.1-8B:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Multi-host Pathways\n", + "\n", + "To run this demo on a multi-host Pathways setup:\n", + "- Set `use_pathways=True` in `rl.yml` (enabled by default).\n", + "- Override `trainer_devices_fraction` and `sampler_devices_fraction` in `config_argv` to split the mesh across hosts.\n", + "- Launch the Colab kernel on the controller host and export Pathways runtime variables (for example `JAX_PLATFORMS=proxy` and `ENABLE_PATHWAYS_PERSISTENCE=1`) before running training.\n", + "- Update `chips_per_vm` to match your slice topology; Pathways will shard trainer and rollout workers automatically.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -85,8 +103,7 @@ "\n", "# Hardcoded defaults for Llama3.1-8B\n", "MODEL_NAME = \"llama3.1-8b\"\n", - "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", - "HF_MODEL_NAME = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n", "CHAT_TEMPLATE_PATH = \"src/MaxText/examples/chat_templates/gsm8k_rl.json\"\n", "\n", "# Required: Set these before running\n", @@ -165,8 +182,8 @@ " \"\", # argv[0] placeholder\n", " config_file,\n", " f\"model_name={MODEL_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " f\"hf_model_name={HF_MODEL_NAME}\",\n", + " f\"tokenizer_path={HF_REPO_ID}\",\n", + " f\"hf_model_name={HF_REPO_ID}\",\n", " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", @@ -180,18 +197,8 @@ " f\"chips_per_vm={CHIPS_PER_VM}\",\n", "]\n", "\n", - "print(\"πŸ“‹ Configuration parameters:\")\n", - "for arg in config_argv[2:]: # Skip argv[0] and config_file\n", - " key, value = arg.split(\"=\", 1)\n", - " # Mask sensitive values\n", - " if \"token\" in key.lower() or \"password\" in key.lower():\n", - " display_value = \"***\" if value else \"not set\"\n", - " else:\n", - " display_value = value[:60] + \"...\" if len(value) > 60 else value\n", - " print(f\" {key}={display_value}\")\n", - "\n", "# Initialize configuration\n", - "print(f\"\\nπŸ”§ Initializing configuration from: {config_file}\")\n", + "print(f\"πŸ”§ Initializing configuration from: {config_file}\")\n", "config = pyconfig.initialize(config_argv)\n", "max_utils.print_system_information()\n", "\n", From 5e0df9a7081527280052469c938b7e94f5078e4e Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 8 Nov 2025 00:24:23 +0400 Subject: [PATCH 06/10] Fix comments Signed-off-by: Vladimir Suvorov --- docs/tutorials/grpo_with_pathways.md | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/docs/tutorials/grpo_with_pathways.md b/docs/tutorials/grpo_with_pathways.md index c2ce421c7..895dbdb51 100644 --- a/docs/tutorials/grpo_with_pathways.md +++ b/docs/tutorials/grpo_with_pathways.md @@ -69,9 +69,4 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ --hf_access_token=$HF_TOKEN" ``` -The overview of the demo script ~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows: - -1. We load a policy model and a reference model. Both are copies of `Llama3.1-70b-Instruct`. -2. Evaluate the policy model's performance on GSM8K math reasoning benchmark. -3. Train the policy model using GRPO with potentially different meshes for trainer and rollout depending on the parameters `TRAINER_DEVICES_FRACTION` and `SAMPLER_DEVICES_FRACTION`. If we set both of these to `1.0`, the entire (same) mesh will be used for both trainer and rollout. If we set say `TRAINER_DEVICES_FRACTION=0.5` and `SAMPLER_DEVICES_FRACTION=0.5`, the first half of the devices will be used for trainer and the second half will be used for rollout -4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO. +For an interactive walkthrough, open `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb`. The notebook now delegates to `rl_train`, so you can reuse the same configuration flags shown above (including `trainer_devices_fraction` and `sampler_devices_fraction`) when scaling to multi-host Pathways or to larger checkpoints such as Llama3.1-70B. From fca89583554937d11141cb1bc4ed134015e10273 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Tue, 11 Nov 2025 20:04:40 +0400 Subject: [PATCH 07/10] Restore pipeline test to main Signed-off-by: Vladimir Suvorov --- tests/pipeline_parallelism_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipeline_parallelism_test.py b/tests/pipeline_parallelism_test.py index bc3cba240..43efb62ca 100644 --- a/tests/pipeline_parallelism_test.py +++ b/tests/pipeline_parallelism_test.py @@ -231,7 +231,6 @@ def test_circular_deepseek_megablox_same_output_and_grad(self): sparse_matmul=False, capacity_factor=1, decoder_block="deepseek", - attention_type="mla", ) self.assert_pipeline_same_output_and_grad(config, single_pipeline_stage_class=deepseek.DeepSeekMoELayer) From dd0c16a7b520982b2348a3d5ccf391ac27cd323c Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Wed, 12 Nov 2025 22:26:55 +0400 Subject: [PATCH 08/10] GRPO/GSPO Signed-off-by: Vladimir Suvorov --- .../examples/grpo_llama3_1_8b_demo.ipynb | 134 ++++++++++++++---- 1 file changed, 110 insertions(+), 24 deletions(-) diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index d9de3ab20..16aedc246 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -37,8 +37,8 @@ "outputs": [], "source": [ "# Clone MaxText repository\n", - "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - "%cd maxtext" + "!git clone https://github.com/AI-Hypercomputer/maxtext\n", + "%cd maxtext/src" ] }, { @@ -47,19 +47,57 @@ "metadata": {}, "outputs": [], "source": [ - "# Install dependencies\n", - "!chmod +x setup.sh\n", - "!./setup.sh\n", + "!bash tools/setup/setup.sh\n", + "%pip uninstall -y jax jaxlib libtpu\n", "\n", - "# Install GRPO-specific dependencies\n", - "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n", + "%pip install aiohttp==3.12.15\n", + "\n", + "# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.\n", + "%pip install keyring keyrings.google-artifactregistry-auth\n", + "\n", + "# Install vLLM for Jax and TPUs from the artifact registry\n", + "!VLLM_TARGET_DEVICE=\"tpu\" pip install --no-cache-dir --pre \\\n", + " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n", + " --extra-index-url https://pypi.org/simple/ \\\n", + " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n", + " --extra-index-url https://download.pytorch.org/whl/nightly/cpu \\\n", + " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n", + " --find-links https://storage.googleapis.com/libtpu-wheels/index.html \\\n", + " --find-links https://storage.googleapis.com/libtpu-releases/index.html \\\n", + " --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n", + " --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \\\n", + " vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu\n", + "\n", + "# Install tpu-commons from the artifact registry\n", + "%pip install --no-cache-dir --pre \\\n", + " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n", + " --extra-index-url https://pypi.org/simple/ \\\n", + " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n", + " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n", + " tpu-commons==0.1.2\n", + "\n", + "%pip install numba==0.61.2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "\n", - "# Install additional requirements\n", - "%pip install --force-reinstall numpy==2.1.2\n", "%pip install nest_asyncio\n", "\n", "import nest_asyncio\n", - "nest_asyncio.apply() # Fix for Colab event loop" + "nest_asyncio.apply() # Fix for Colab event loop\n", + "\n", + "%cd maxtext/src/\n", + "\n", + "#Fix nnx problems\n", + "!pip uninstall flax \n", + "!pip uninstall qwix\n", + "!pip install flax \n", + "!pip install qwix" ] }, { @@ -97,19 +135,21 @@ "source": [ "# Configuration for GRPO training\n", "import os\n", + "import MaxText\n", "\n", "# Set up paths (adjust if needed)\n", - "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n", - "\n", + "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", + "RUN_NAME=\"grpo_test\"\n", "# Hardcoded defaults for Llama3.1-8B\n", "MODEL_NAME = \"llama3.1-8b\"\n", "HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n", - "CHAT_TEMPLATE_PATH = \"src/MaxText/examples/chat_templates/gsm8k_rl.json\"\n", + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n", + "LOSS_ALGO=\"gspo-token\"\n", "\n", "# Required: Set these before running\n", - "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\" # Update this!\n", - "OUTPUT_DIRECTORY = \"/tmp/grpo_output\" # Update this!\n", - "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\") # Set HF_TOKEN environment variable\n", + "MODEL_CHECKPOINT_PATH = \"\" # Update this!\n", + "OUTPUT_DIRECTORY = \"/tmp/gpo_output\" # Update this!\n", + "HF_TOKEN = \"\" # Set HF_TOKEN environment variable\n", "\n", "# Optional: Override training parameters\n", "STEPS = 10 # Reduced for demo purposes\n", @@ -118,14 +158,15 @@ "NUM_GENERATIONS = 2\n", "GRPO_BETA = 0.08\n", "GRPO_EPSILON = 0.2\n", - "CHIPS_PER_VM = 4\n", + "CHIPS_PER_VM = 1\n", "\n", "print(f\"πŸ“ MaxText Home: {MAXTEXT_REPO_ROOT}\")\n", "print(f\"πŸ€– Model: {MODEL_NAME}\")\n", "print(f\"πŸ“¦ Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", "print(f\"πŸ’Ύ Output: {OUTPUT_DIRECTORY}\")\n", "print(f\"πŸ”‘ HF Token: {'βœ… Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n", - "print(f\"πŸ“Š Steps: {STEPS}\")" + "print(f\"πŸ“Š Steps: {STEPS}\")\n", + "print(f\"Loss Algorithm : {LOSS_ALGO}\")" ] }, { @@ -140,7 +181,7 @@ "from pathlib import Path\n", "\n", "# Add MaxText to Python path\n", - "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n", + "maxtext_path = Path(MAXTEXT_REPO_ROOT) \n", "sys.path.insert(0, str(maxtext_path))\n", "\n", "from MaxText import pyconfig, max_utils\n", @@ -163,6 +204,51 @@ "print(f\"πŸ“ MaxText path: {maxtext_path}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build configuration for GRPO training\n", + "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n", + "\n", + "# Verify chat template exists\n", + "if not os.path.exists(os.path.join(MAXTEXT_REPO_ROOT, CHAT_TEMPLATE_PATH)):\n", + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", + "\n", + "# Build argv list for pyconfig.initialize()\n", + "config_argv = [\n", + " \"\", # argv[0] placeholder\n", + " config_file,\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"tokenizer_path={HF_REPO_ID}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", + " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"steps={STEPS}\",\n", + " f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n", + " f\"learning_rate={LEARNING_RATE}\",\n", + " f\"num_generations={NUM_GENERATIONS}\",\n", + " f\"grpo_beta={GRPO_BETA}\",\n", + " f\"grpo_epsilon={GRPO_EPSILON}\",\n", + " f\"chips_per_vm={CHIPS_PER_VM}\",\n", + " f\"loss_algo={LOSS_ALGO}\"\n", + "]\n", + "\n", + "# Initialize configuration\n", + "print(f\"πŸ”§ Initializing configuration from: {config_file}\")\n", + "config = pyconfig.initialize(config_argv)\n", + "max_utils.print_system_information()\n", + "\n", + "print(\"\\nβœ… Configuration initialized successfully\")\n", + "print(f\"πŸ“Š Training steps: {config.steps}\")\n", + "print(f\"πŸ“ Output directory: {config.base_output_directory}\")\n", + "print(f\"πŸ€– Model: {config.model_name}\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -214,17 +300,17 @@ "metadata": {}, "outputs": [], "source": [ - "# Execute GRPO training\n", + "# Execute GRPO/GSPO training\n", "print(\"\\n\" + \"=\"*80)\n", - "print(\"πŸš€ Starting GRPO Training...\")\n", + "print(\"πŸš€ Starting Training...\")\n", "print(\"=\"*80)\n", - "\n", + "print(1)\n", "try:\n", " # Call the rl_train function (it handles everything internally)\n", " rl_train(config)\n", " \n", " print(\"\\n\" + \"=\"*80)\n", - " print(\"βœ… GRPO Training Completed Successfully!\")\n", + " print(\"βœ… Training Completed Successfully!\")\n", " print(\"=\"*80)\n", " print(f\"πŸ“ Checkpoints saved to: {config.checkpoint_dir}\")\n", " print(f\"πŸ“Š TensorBoard logs: {config.tensorboard_dir}\")\n", @@ -232,7 +318,7 @@ " \n", "except Exception as e:\n", " print(\"\\n\" + \"=\"*80)\n", - " print(\"❌ GRPO Training Failed!\")\n", + " print(\"❌Training Failed!\")\n", " print(\"=\"*80)\n", " print(f\"Error: {str(e)}\")\n", " import traceback\n", From 70a7ca95ff260f2e608f4f2c7608391aa04ed63b Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Wed, 12 Nov 2025 22:32:49 +0400 Subject: [PATCH 09/10] Fix Signed-off-by: Vladimir Suvorov --- src/MaxText/layerwise_quantization.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/MaxText/layerwise_quantization.py b/src/MaxText/layerwise_quantization.py index e9838b402..4b5fe045d 100644 --- a/src/MaxText/layerwise_quantization.py +++ b/src/MaxText/layerwise_quantization.py @@ -94,16 +94,8 @@ def load_and_quantize(self, rng: None | PRNGKeyType = None) -> None: self.quant.quant_mode = quantizations.get_quant_mode("convert") layers = [ - deepseek.DeepSeekDenseLayer( # pylint: disable=no-value-for-parameter - config, - mesh=self._mesh, - quant=self.quant, - ), - deepseek.DeepSeekMoELayer( # pylint: disable=no-value-for-parameter - config, - mesh=self._mesh, - quant=self.quant, - ), + deepseek.DeepSeekDenseLayer(config, mesh=self._mesh, quant=self.quant), + deepseek.DeepSeekMoELayer(config, mesh=self._mesh, quant=self.quant), ] layer_prefixes = ["dense_layers", "moe_layers"] num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers From 40fea137b946fa820017954a9133c2f6e63d0212 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Wed, 12 Nov 2025 23:18:51 +0400 Subject: [PATCH 10/10] Fix Signed-off-by: Vladimir Suvorov --- .../examples/grpo_llama3_1_8b_demo.ipynb | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index 16aedc246..e17221eaf 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -249,51 +249,6 @@ "print(f\"πŸ€– Model: {config.model_name}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Build configuration for GRPO training\n", - "# Using rl.yml as the base config (not grpo.yml)\n", - "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"src/MaxText/configs/rl.yml\")\n", - "\n", - "# Verify chat template exists\n", - "if not os.path.exists(os.path.join(MAXTEXT_REPO_ROOT, CHAT_TEMPLATE_PATH)):\n", - " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", - "\n", - "# Build argv list for pyconfig.initialize()\n", - "config_argv = [\n", - " \"\", # argv[0] placeholder\n", - " config_file,\n", - " f\"model_name={MODEL_NAME}\",\n", - " f\"tokenizer_path={HF_REPO_ID}\",\n", - " f\"hf_model_name={HF_REPO_ID}\",\n", - " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", - " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", - " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"steps={STEPS}\",\n", - " f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n", - " f\"learning_rate={LEARNING_RATE}\",\n", - " f\"num_generations={NUM_GENERATIONS}\",\n", - " f\"grpo_beta={GRPO_BETA}\",\n", - " f\"grpo_epsilon={GRPO_EPSILON}\",\n", - " f\"chips_per_vm={CHIPS_PER_VM}\",\n", - "]\n", - "\n", - "# Initialize configuration\n", - "print(f\"πŸ”§ Initializing configuration from: {config_file}\")\n", - "config = pyconfig.initialize(config_argv)\n", - "max_utils.print_system_information()\n", - "\n", - "print(\"\\nβœ… Configuration initialized successfully\")\n", - "print(f\"πŸ“Š Training steps: {config.steps}\")\n", - "print(f\"πŸ“ Output directory: {config.base_output_directory}\")\n", - "print(f\"πŸ€– Model: {config.model_name}\")" - ] - }, { "cell_type": "code", "execution_count": null,