diff --git a/src/MaxText/examples/reinforcement_learning_grpo.py b/src/MaxText/examples/reinforcement_learning_grpo.py
new file mode 100644
index 000000000..c2b84cdab
--- /dev/null
+++ b/src/MaxText/examples/reinforcement_learning_grpo.py
@@ -0,0 +1,1262 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=bare-except, consider-using-generator
+"""
+GRPO (Group Relative Policy Optimization) Tutorial
+===================================================
+
+This tutorial demonstrates training the Llama 3.1 8B model on the GSM8K math reasoning
+benchmark using Group Relative Policy Optimization (GRPO).
+
+What is GRPO?
+-------------
+GRPO is a Reinforcement Learning algorithm designed to enhance reasoning abilities
+of Large Language Models. It's a memory-efficient variant of PPO (Proximal Policy
+Optimization) that:
+ - Eliminates the need for a separate value function model (saves memory)
+ - Generates multiple responses per prompt (the "group")
+ - Evaluates responses using reward functions
+ - Calculates relative advantage based on group performance
+ - Updates the policy to improve future generations
+
+Why use GRPO?
+-------------
+GRPO can enhance your model's problem-solving skills on:
+ - Mathematical word problems (like GSM8K)
+ - Coding challenges
+ - Reasoning tasks
+ - Any task where you can define a reward function
+
+Architecture Overview:
+---------------------
+ - Tunix: Main library for GRPO training orchestration
+ - vLLM: Efficient inference engine for generating responses during training
+ - MaxText: Model implementation (supports Qwen3, Llama, Gemma, etc.)
+ - JAX/Flax: Backend for training
+
+Hardware Requirements:
+---------------------
+This tutorial runs on a single-host TPU VM (e.g., v6e-8 or v5p-8).
+
+Let's get started!
+"""
+
+# ==============================================================================
+# STEP 0: IMPORTS AND ENVIRONMENT SETUP
+# ==============================================================================
+
+# Standard library imports
+from pprint import pprint
+import functools
+import os
+import re
+
+# Third-party imports
+from tqdm.auto import tqdm
+import grain
+import humanize
+import jax
+from flax import nnx
+from flax.linen import partitioning as nn_partitioning
+import optax
+from orbax import checkpoint as ocp
+import tensorflow_datasets as tfds
+from transformers import AutoTokenizer
+
+# Tunix imports (GRPO training framework)
+from tunix.rl import rl_cluster as rl_cluster_lib
+from tunix.rl.rollout import base_rollout
+from tunix.rl.rollout.base_rollout import RolloutConfig
+from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
+from tunix.sft import metrics_logger
+from tunix.models.llama3 import model as llama3_lib
+
+# MaxText imports (model implementation)
+from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR
+from MaxText import model_creation_utils
+from MaxText import pyconfig
+from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
+
+# Environment setup for vLLM
+# Skip JAX precompilation to speed up startup when using vLLM
+os.environ["SKIP_JAX_PRECOMPILE"] = "1"
+
+# For Colab/notebook environments (uncomment if needed):
+# import nest_asyncio
+# nest_asyncio.apply() # Fixes "This event loop is already running" error
+
+# Initialize JAX devices (TPU/GPU)
+jax.devices()
+
+# Global settings
+DEBUG = True # Set to False to reduce the detailed debug output printed during training
+HOME = os.path.join(os.path.expanduser("~"), "")
+print(f"Home directory (from Python): {HOME}")
+
+# ==============================================================================
+# INSTALLATION REQUIREMENTS
+# ==============================================================================
+# Before running this tutorial, ensure you have installed all required packages.
+# For detailed instructions, refer to:
+# https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html
+#
+# Quick setup:
+# 1. Run the initial setup script:
+# bash tools/setup/setup.sh
+#
+# 2. Activate your virtual environment:
+# venv_name="maxtext_venv" # Replace with your venv name if different
+# source ~/$venv_name/bin/activate
+#
+# 3. Install vLLM and tpu-commons dependencies:
+# bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh
+# Note: This installation may take several minutes. Monitor the logs for any errors.
+
+# ==============================================================================
+# STEP 1: CONFIGURE HYPERPARAMETERS
+# ==============================================================================
+#
+# This section defines all hyperparameters for the GRPO training pipeline.
+# These are organized into logical categories for easy understanding and tuning.
+#
+# Note: These are not "perfect" hyperparameters. For production results,
+# you may need to tune these values and train for longer.
+
+# ==============================================================================
+# DATA CONFIGURATION
+# ==============================================================================
+# Directories for storing training and test datasets
+TRAIN_DATA_DIR = os.path.join(HOME, "data", "train")
+TEST_DATA_DIR = os.path.join(HOME, "data", "test")
+if not os.path.exists(TRAIN_DATA_DIR):
+ os.makedirs(TRAIN_DATA_DIR)
+if not os.path.exists(TEST_DATA_DIR):
+ os.makedirs(TEST_DATA_DIR)
+
+# Fraction of training data to use (1.0 = 100%, use all data)
+# Set < 1.0 to create a validation split
+TRAIN_FRACTION = 1.0
+
+# ==============================================================================
+# MODEL & CHECKPOINT CONFIGURATION
+# ==============================================================================
+# Path to pre-trained model checkpoint (can be local or GCS path)
+# For Llama 3 8B, you'll need to convert the HuggingFace checkpoint to MaxText format
+# See: /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py
+MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items"
+
+# Directory for TensorBoard logs (training metrics visualization)
+LOG_DIR = os.path.join(HOME, "content", "tensorboard", "grpo", "logs_llama3", "")
+if not os.path.exists(LOG_DIR):
+ os.makedirs(LOG_DIR)
+
+# Directory for JAX profiling traces (performance analysis)
+PROFILE_DIR = os.path.join(HOME, "content", "jax_traces", "grpo", "profiles_llama3", "")
+if not os.path.exists(PROFILE_DIR):
+ os.makedirs(PROFILE_DIR)
+
+# Directory for saving training checkpoints
+CKPT_DIR = os.path.join(HOME, "content", "ckpts_llama3", "")
+if not os.path.exists(CKPT_DIR):
+ os.makedirs(CKPT_DIR)
+
+# Checkpoint saving frequency (save every N steps)
+SAVE_INTERVAL_STEPS = 500
+
+# Maximum number of checkpoints to retain (older ones are deleted)
+MAX_TO_KEEP = 4
+
+# Random seed for reproducibility (data shuffling, sampling, etc.)
+SEED = 42
+
+# ==============================================================================
+# GRPO ALGORITHM PARAMETERS
+# ==============================================================================
+# Number of responses generated per prompt in each training step
+# This is the "G" (group size) in GRPO Algorithm 1
+# Larger values provide better advantage estimates but increase compute
+NUM_GENERATIONS = 2
+
+# Number of optimization iterations per batch (μ in GRPO Algorithm 1)
+# Higher values = more gradient steps per batch of data
+NUM_ITERATIONS = 1
+
+# KL divergence penalty coefficient (β in GRPO loss function)
+# Controls how much the policy can deviate from the reference model
+# Too low: the policy may drift too far from the reference; too high: updates become overly conservative
+BETA = 0.08
+
+# PPO-style clipping parameter (ε in GRPO loss)
+# Prevents excessively large policy updates for training stability
+EPSILON = 0.2
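+
+# For intuition, here is a minimal, illustrative sketch (Tunix performs this
+# computation internally, and its exact normalization may differ in detail) of how
+# GRPO turns the rewards of the NUM_GENERATIONS responses sharing one prompt into
+# group-relative advantages:
+if DEBUG:
+  _demo_rewards = [3.0, 0.5] # hypothetical rewards for a group of 2 generations
+  _demo_mean = sum(_demo_rewards) / len(_demo_rewards)
+  _demo_std = (sum((r - _demo_mean) ** 2 for r in _demo_rewards) / len(_demo_rewards)) ** 0.5
+  _demo_advantages = [(r - _demo_mean) / (_demo_std + 1e-6) for r in _demo_rewards]
+  print(f"Example group-relative advantages: {_demo_advantages}")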
+
+# ==============================================================================
+# GENERATION/SAMPLING PARAMETERS (During Training)
+# ==============================================================================
+# Maximum length of input prompts (tokens)
+MAX_PROMPT_LENGTH = 512
+
+# Maximum number of tokens to generate per response
+TOTAL_GENERATION_STEPS = 1024
+
+# Sampling temperature during training rollouts
+# Higher values (0.9) encourage diversity and exploration
+# This is important for GRPO to generate varied responses
+TEMPERATURE = 0.9
+
+# Top-p (nucleus) sampling parameter
+# 1.0 = consider all tokens in the distribution
+TOP_P = 1.0
+
+# Top-k sampling parameter
+# Only sample from the top K most likely tokens
+TOP_K = 50
+
+# ==============================================================================
+# TRAINING CONFIGURATION
+# ==============================================================================
+# Batch size per device (number of prompts processed together)
+BATCH_SIZE = 1
+
+# Number of batches to train on
+# Increase for better results (original: 3738, reduced for demo)
+NUM_BATCHES = 500
+
+# Number of batches to use for testing/evaluation
+# Keep low for quick evaluation (GSM8K's test split has 1,319 problems, i.e. ~330 batches at batch size 4)
+NUM_TEST_BATCHES = 200
+
+# Evaluate on validation set every N steps
+# (Not used if TRAIN_FRACTION = 1.0, no validation split)
+EVAL_EVERY_N_STEPS = 10
+
+# Number of times to iterate over the entire dataset
+NUM_EPOCHS = 1
+
+# Total number of training steps (computed from other params)
+MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
+
+# ==============================================================================
+# OPTIMIZER & LEARNING RATE SCHEDULE
+# ==============================================================================
+# Peak learning rate for AdamW optimizer
+LEARNING_RATE = 3e-6
+
+# AdamW beta1 parameter (momentum for first moment estimates)
+B1 = 0.9
+
+# AdamW beta2 parameter (momentum for second moment estimates)
+B2 = 0.99
+
+# Weight decay coefficient for L2 regularization
+WEIGHT_DECAY = 0.1
+
+# Number of warmup steps for learning rate schedule
+# LR linearly increases from 0 to LEARNING_RATE over this period
+# Then cosine decays to 0 over remaining steps
+WARMUP_STEPS = int(0.1 * MAX_STEPS)
+
+# Maximum gradient norm for gradient clipping
+# Prevents exploding gradients and helps maintain stable KL divergence
+# Set to None to disable gradient clipping
+MAX_GRAD_NORM = 0.1
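+
+# As a quick, optional illustration of the schedule described above, probe the same
+# optax schedule that main() builds later (the trainer constructs its own copy; this
+# probe is for logging only):
+if DEBUG:
+  _demo_schedule = optax.schedules.warmup_cosine_decay_schedule(
+      init_value=0.0,
+      peak_value=LEARNING_RATE,
+      warmup_steps=WARMUP_STEPS,
+      decay_steps=MAX_STEPS,
+      end_value=0.0,
+  )
+  for _probe_step in (0, WARMUP_STEPS, MAX_STEPS // 2, MAX_STEPS):
+    print(f"Learning rate at step {_probe_step}: {float(_demo_schedule(_probe_step)):.2e}")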
+
+# ==============================================================================
+# EVALUATION/INFERENCE CONFIGURATIONS
+# ==============================================================================
+# Different sampling strategies for evaluation
+GENERATION_CONFIGS = {
+ # Greedy decoding: deterministic, always picks most likely token
+ "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
+
+ # Standard sampling: balanced exploration
+ "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
+
+ # Liberal sampling: high diversity
+ "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
+}
+
+# ==============================================================================
+# REWARD FUNCTION PARAMETERS
+# ==============================================================================
+# Rewards for correct formatting
+REWARD_EXACT_FORMAT_MATCH = 3.0 # Perfect format match
+REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5 # Match with whitespace differences
+REWARD_PARTIAL_FORMAT_MATCH = 0.5 # Partial format compliance
+
+# Rewards for answer correctness
+REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5 # Answer within 10% of correct value
+REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25 # Answer within 20% of correct value
+
+# Penalties for mistakes
+PENALTY_INCORRECT_FORMAT = -0.5 # Wrong formatting
+PENALTY_INCORRECT_ANSWER = -1.0 # Wrong answer
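+
+# Taken together (assuming, as in the GRPO learner setup below, that the outputs of
+# all four reward functions from Step 6 are summed), a perfectly formatted response
+# with an exactly correct answer earns roughly
+#   3.0 (exact format) + 4 * 0.5 (each tag appears once) + 3.0 (exact answer) + 1.5 (number check) = 9.5
+# points, while formatting mistakes and wrong answers pull the score toward the penalties above.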
+
+
+# ==============================================================================
+# STEP 2: DEFINE UTILITY FUNCTIONS
+# ==============================================================================
+#
+# Helper functions for monitoring training progress and system resources
+
+
+def show_hbm_usage():
+ """Displays memory usage per device."""
+ fmt_size = functools.partial(humanize.naturalsize, binary=True)
+
+ for d in jax.local_devices():
+ stats = d.memory_stats()
+ used = stats["bytes_in_use"]
+ limit = stats["bytes_limit"]
+ print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")
+
+
+# ==============================================================================
+# STEP 3: DATA PREPROCESSING & TOKENIZER SETUP
+# ==============================================================================
+#
+# This section:
+# 1. Loads the tokenizer for Llama 3.1 8B
+# 2. Defines special tokens for structured output (reasoning + answer format)
+# 3. Creates prompt templates for the GSM8K math reasoning task
+#
+# We instruct the model to use a specific format:
+#   <reasoning> ...model's reasoning... </reasoning>
+#   <answer> ...final numerical answer... </answer>
+#
+# This structured format helps with:
+# - Evaluating the reasoning process
+# - Extracting and verifying the final answer
+# - Providing clearer rewards during RL training
+
+# Load tokenizer for Llama 3.1 8B from HuggingFace
+model_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
+
+# Define special tokens for structured output format
+reasoning_start = "<reasoning>"
+reasoning_end = "</reasoning>"
+solution_start = "<answer>"
+solution_end = "</answer>"
+
+# System prompt that instructs the model on the expected format
+SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
+provide your reasoning. Place it between {reasoning_start} and \
+{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
+value) between {solution_start} and {solution_end}."""
+
+# Prompt template with explicit turn markers. Prompts are built by plain string
+# formatting below; the tokenizer's chat template is not applied.
+TEMPLATE = """<start_of_turn>user
+{system_prompt}
+
+{question}<end_of_turn>
+<start_of_turn>model"""
+
+# ==============================================================================
+# STEP 4: DATASET CREATION
+# ==============================================================================
+#
+# We use GSM8K (Grade School Math 8K) - a dataset of grade school math word
+# problems requiring multi-step reasoning. Perfect for testing GRPO's ability
+# to improve reasoning capabilities!
+
+
+def extract_hash_answer(text: str) -> str | None:
+ """Extracts the answer from a string that contains '####'.
+
+ Args:
+ text: The string to extract the answer from.
+
+ Returns:
+ The extracted answer as a string, or None if '####' is not found.
+ """
+ if DEBUG:
+ print(f"Extracting answer from: {text}")
+ if "####" not in text:
+ return None
+ return text.split("####")[1].strip()
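+
+
+# Quick sanity check of the answer extraction on a hand-written, GSM8K-style string
+# (illustrative only, not taken from the dataset):
+if DEBUG:
+  assert extract_hash_answer("He pays 10 * 6 = 60 dollars.\n#### 60") == "60"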
+
+
+def get_dataset(data_dir, split="train") -> grain.MapDataset:
+ """Gets and preprocesses the GSM8K dataset.
+
+ Args:
+ data_dir: The directory to download and store the dataset.
+ split: The dataset split to use (e.g., 'train', 'test').
+
+ Returns:
+ A grain.MapDataset containing the preprocessed data.
+ """
+ # Download data
+ if not os.path.exists(data_dir):
+ os.makedirs(data_dir)
+
+ data = tfds.data_source(
+ "gsm8k",
+ split=split,
+ data_dir=data_dir,
+ builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
+ download=True,
+ )
+
+ loaded_dataset = (
+ grain.MapDataset.source(data)
+ .shuffle(seed=SEED)
+ .map(
+ lambda x: {
+ # Prompts are passed to model forward pass
+ "prompts": TEMPLATE.format(
+ system_prompt=SYSTEM_PROMPT,
+ question=x["question"].decode("utf-8"),
+ ),
+ # Question and answer are passed to reward functions
+ "question": x["question"].decode("utf-8"),
+ "answer": extract_hash_answer(x["answer"].decode("utf-8")),
+ }
+ )
+ )
+ return loaded_dataset
+
+
+DATASET = get_dataset(TRAIN_DATA_DIR, "train").batch(BATCH_SIZE)[:NUM_BATCHES]
+
+if TRAIN_FRACTION == 1.0:
+ train_dataset = DATASET.repeat(NUM_EPOCHS)
+ val_dataset = None
+else:
+ train_dataset = DATASET[: int(len(DATASET) * TRAIN_FRACTION)]
+ train_dataset = train_dataset.repeat(NUM_EPOCHS)
+
+ val_dataset = DATASET[int(len(DATASET) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)
+
+test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[:NUM_TEST_BATCHES]
+
+
+# Debug: Print sample batch to verify data preprocessing
+if DEBUG:
+ print("Sample batch from dataset:")
+ for ele in train_dataset[:1]:
+ pprint(ele)
+
+
+# ==============================================================================
+# STEP 5: LOAD POLICY AND REFERENCE MODELS
+# ==============================================================================
+#
+# GRPO requires TWO models:
+#
+# 1. POLICY MODEL (Actor):
+# - This is the model being trained
+# - Weights are updated during training
+# - Generates responses during rollouts
+#
+# 2. REFERENCE MODEL:
+# - Frozen copy of the original model
+# - Used to compute KL divergence penalty
+# - Prevents the policy from deviating too far from original behavior
+# - Ensures training stability
+#
+# Training Strategy:
+# ------------------
+# - This script uses FULL model training (all parameters updated)
+# - For memory efficiency, you could use LoRA (Low-Rank Adaptation):
+# * Freeze base model weights
+# * Only train small LoRA adapters
+# * Significantly reduces memory usage
+#
+# Precision:
+# ---------
+# - Using bfloat16 for memory efficiency
+# - For even lower precision, consider Qwix for Quantization-Aware Training
+
+print("HBM usage before loading models:")
+show_hbm_usage()
+
+
+# ### Helper function to create MaxText models
+
+
+def get_ref_maxtext_model(config):
+ """Creates and returns a TunixMaxTextAdapter model and mesh.
+
+ Args:
+ config: The model configuration.
+
+ Returns:
+ A tuple containing the TunixMaxTextAdapter model and the mesh.
+ """
+
+ model, this_mesh = model_creation_utils.create_nnx_model(config)
+ with this_mesh:
+ tunix_model = TunixMaxTextAdapter(
+ base_model=model,
+ )
+
+ this_model_config = llama3_lib.ModelConfig.llama3_1_8b()
+ tunix_model.config = this_model_config
+
+ return tunix_model, this_mesh
+
+
+model_config = llama3_lib.ModelConfig.llama3_1_8b()
+
+# Configure and load the reference model
+# Note: pass the path to your scanned checkpoint for "load_parameters_path".
+# To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py
+config_ref = pyconfig.initialize(
+ [
+ "",
+ f"{HOME}/maxtext/src/MaxText/configs/base.yml",
+ ],
+ base_output_directory="dummy", # This is not used in Tunix.
+ run_name="test-tunix-maxtext-llama3.1-8b",
+ tokenizer_type="huggingface",
+ tokenizer_path="meta-llama/Llama-3.1-8B",
+ load_parameters_path=MODEL_CHECKPOINT_PATH,
+ per_device_batch_size=1,
+ max_prefill_predict_length=4,
+ max_target_length=1024,
+ steps=10,
+ async_checkpointing="false",
+ model_name="llama3.1-8b",
+ checkpoint_period=5,
+ skip_jax_distributed_system="true",
+ weight_dtype="bfloat16",
+ attention="dot_product",
+ remat_policy="custom",
+ decoder_layer_input="offload",
+ query_proj="offload",
+ key_proj="offload",
+ value_proj="offload",
+)
+
+llama3_8b, mesh = get_ref_maxtext_model(config_ref)
+
+llama3_8b.config = model_config
+
+nnx.display(llama3_8b)
+
+
+if DEBUG:
+ print("Model initialized successfully")
+ print(f"Model mesh shape: {mesh.shape}")
+ print(f"Model config: {model_config}")
+
+ # Sanity check that weights are loaded correctly
+ _maxtext_state_flatten = nnx.state(llama3_8b).flat_state()
+ maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
+ print(
+ f"maxtext_state_flatten[base.token_embedder.embedding].value="
+ f"{maxtext_state_flatten['base.token_embedder.embedding'].value}"
+ )
+
+
+# See the memory use after loading the reference model:
+print("HBM usage after loading ref model:")
+show_hbm_usage()
+
+
+# Configure and load the policy model
+config_policy = pyconfig.initialize(
+ [
+ "",
+ f"{HOME}/maxtext/src/MaxText/configs/base.yml",
+ ],
+ base_output_directory="dummy", # This is not used in Tunix.
+ run_name="test-tunix-maxtext-llama3.1-8b", # This is not used in Tunix.
+ tokenizer_type="huggingface",
+ tokenizer_path="meta-llama/Llama-3.1-8B",
+ load_parameters_path=MODEL_CHECKPOINT_PATH,
+ per_device_batch_size=1,
+ max_prefill_predict_length=4,
+ max_target_length=1024,
+ steps=10,
+ async_checkpointing="false",
+ model_name="llama3.1-8b",
+ checkpoint_period=5,
+ skip_jax_distributed_system="true",
+ weight_dtype="bfloat16",
+ attention="dot_product",
+ remat_policy="custom",
+ decoder_layer_input="offload",
+ query_proj="offload",
+ key_proj="offload",
+ value_proj="offload",
+)
+llama3_8b_policy, mesh_policy = get_ref_maxtext_model(config_policy)
+
+llama3_8b_policy.config = model_config
+
+nnx.display(llama3_8b_policy)
+
+if DEBUG:
+ print("Model initialized successfully")
+ print(f"Model mesh shape: {mesh_policy.shape}")
+
+ # Sanity check that weights are loaded correctly
+ _maxtext_state_flatten = nnx.state(llama3_8b_policy).flat_state()
+ maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
+ print(
+ f"maxtext_state_flatten[base.token_embedder.embedding].value="
+ f"{maxtext_state_flatten['base.token_embedder.embedding'].value}"
+ )
+
+# See memory usage after loading the policy model:
+print("HBM usage after loading policy model:")
+show_hbm_usage()
+
+
+# ==============================================================================
+# STEP 6: DEFINE REWARD FUNCTIONS
+# ==============================================================================
+#
+# Reward functions are the heart of GRPO - they tell the model what behavior
+# to reinforce. We define FOUR reward functions for the GSM8K task:
+#
+# 1. match_format_exactly: Rewards exact format compliance
+#    - Checks if output has <reasoning>...</reasoning> and <answer>...</answer>
+# - Reward: +3.0 points
+#
+# 2. match_format_approximately: Rewards partial format compliance
+# - Checks if special tokens appear once each (no duplicates/missing)
+# - Reward: +0.5 per correct token, -0.5 penalty per incorrect
+#
+# 3. check_answer: Rewards correct numerical answers
+# - Exact match: +3.0
+# - With whitespace differences: +1.5
+# - Within 10% of correct: +0.5
+# - Within 20% of correct: +0.25
+# - Wrong answer: -1.0
+#
+# 4. check_numbers: Fallback for extracting numbers from verbose answers
+#    - Extracts first number from the <answer> section
+# - Exact match: +1.5
+#
+# Inspiration: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
+
+# Regular expression to match the expected format
+match_format = re.compile(
+ rf"^[\s]{{0,}}" rf"{reasoning_start}.+?{reasoning_end}.*?" rf"{solution_start}(.+?){solution_end}" rf"[\s]{{0,}}$",
+ flags=re.MULTILINE | re.DOTALL,
+)
+
+# Quick self-check that the regex accepts a well-formed response
+_format_test = f"{reasoning_start}Let me think!{reasoning_end}{solution_start}2{solution_end}"
+assert match_format.search(_format_test) is not None
+
+
+# --- Reward Function 1: Exact Format Match ---
+
+
+def match_format_exactly(prompts, completions, **kargs):
+ """Rewards completions that exactly match the specified format.
+
+ Args:
+ prompts: The prompts used to generate completions.
+ completions: The generated completions.
+ **kargs: Additional keyword arguments.
+
+ Returns:
+ A list of scores for each completion.
+ """
+ scores = []
+ for completion in completions:
+ score = 0
+ response = completion
+ # Match if format is seen exactly!
+ if match_format.search(response) is not None:
+ score += REWARD_EXACT_FORMAT_MATCH
+ scores.append(score)
+ return scores
+
+
+# --- Reward Function 2: Approximate Format Match ---
+
+
+def match_format_approximately(prompts, completions, **kargs):
+ """Rewards completions that approximately match the specified format.
+
+ Args:
+ prompts: The prompts used to generate completions.
+ completions: The generated completions.
+ **kargs: Additional keyword arguments.
+
+ Returns:
+ A list of scores for each completion.
+ """
+ scores = []
+
+ for completion in completions:
+ score = 0
+ response = completion
+    # Count how many of the special tokens appear - penalize duplicates or omissions!
+    # If a token is seen exactly once, add some points!
+ score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_start) == 1 else PENALTY_INCORRECT_FORMAT
+ score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_end) == 1 else PENALTY_INCORRECT_FORMAT
+ score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_start) == 1 else PENALTY_INCORRECT_FORMAT
+ score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_end) == 1 else PENALTY_INCORRECT_FORMAT
+ scores.append(score)
+ return scores
+
+
+# --- Reward Function 3: Answer Correctness Check ---
+#
+# This function rewards correct answers with partial credit for close answers
+
+
+def check_answer(prompts, completions, answer, **kargs):
+ """Checks if the answer in the completion is correct and rewards accordingly.
+
+ Args:
+ prompts: The prompts used to generate completions.
+ completions: The generated completions.
+ answer: The ground truth answers.
+ **kargs: Additional keyword arguments.
+
+ Returns:
+ A list of scores for each completion.
+ """
+ responses = completions
+
+ extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in responses]
+
+ scores = []
+ for guess, true_answer in zip(extracted_responses, answer):
+ score = 0
+ if guess is None:
+ scores.append(0)
+ continue
+ # Correct answer gets 3 points!
+ if guess == true_answer:
+ score += REWARD_EXACT_FORMAT_MATCH
+ # Match if spaces are seen
+ elif guess.strip() == true_answer.strip():
+ score += REWARD_WHITE_SPACE_FORMAT_MATCH
+ else:
+      # We also reward the answer if it is numerically close, via ratios!
+      # I.e., if the guess is within some range of the true value, reward it!
+ try:
+ ratio = float(guess) / float(true_answer)
+ if 0.9 <= ratio <= 1.1:
+ score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH
+ elif 0.8 <= ratio <= 1.2:
+ score += REWARD_RATIO_GUESS_TO_ANSWER_LOW
+ else:
+ score += PENALTY_INCORRECT_ANSWER # Penalize wrong answers
+ except (ValueError, TypeError, ZeroDivisionError):
+ score += PENALTY_INCORRECT_FORMAT # Penalize
+ scores.append(score)
+ return scores
+
+
+# --- Reward Function 4: Number Extraction Fallback ---
+#
+# Sometimes the answer section contains text instead of just a number.
+# This function extracts the first number found and checks correctness.
+# Useful when the model provides verbose answers like "The answer is 42"
+
+# Regex to extract the first number from the answer section
+match_numbers = re.compile(rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL)
+# Quick self-check: the regex should pull "0.34" out of a sample answer block
+assert match_numbers.findall(f"{solution_start} 0.34 {solution_end}") == ["0.34"]
+
+
+def check_numbers(prompts, completions, answer, **kargs):
+ """Extracts numbers from completions and rewards if they match the answer.
+
+ Args:
+ prompts: The prompts used to generate completions.
+ completions: The generated completions.
+ answer: The ground truth answers.
+ **kargs: Additional keyword arguments.
+
+ Returns:
+ A list of scores for each completion.
+ """
+ question = kargs["question"]
+ responses = completions
+
+ extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses]
+
+ scores = []
+ if DEBUG:
+ print("START ============================")
+ print(f"Question: {question[0]}")
+ print(f"Answer: {answer[0]}")
+ print(f"Response: {responses[0]}")
+ print(f"Extracted: {extracted_responses[0]}")
+ print("END ==============================")
+ for guess, true_answer in zip(extracted_responses, answer):
+ if guess is None:
+ scores.append(0)
+ continue
+ # Convert to numbers
+ try:
+ true_answer = float(true_answer.strip())
+ guess = float(guess.strip())
+ scores.append(1.5 if guess == true_answer else 0.0)
+ except (ValueError, TypeError):
+ scores.append(0)
+ continue
+ return scores
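+
+
+# Optional end-to-end check of the reward functions on one hand-written completion.
+# The strings below are made up for illustration; during training these scores are
+# computed by the GRPO learner rather than called directly. With the constants
+# defined above, this should report scores of 3.0, 2.0, 3.0 and 1.5 respectively.
+if DEBUG:
+  _demo_completion = f"{reasoning_start}18 - 5 = 13{reasoning_end}{solution_start}13{solution_end}"
+  print("match_format_exactly:", match_format_exactly(None, [_demo_completion]))
+  print("match_format_approximately:", match_format_approximately(None, [_demo_completion]))
+  print("check_answer:", check_answer(None, [_demo_completion], ["13"]))
+  print("check_numbers:", check_numbers(None, [_demo_completion], ["13"], question=["(demo question)"]))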
+
+
+# ==============================================================================
+# STEP 7: DEFINE EVALUATION FUNCTIONS
+# ==============================================================================
+#
+# Evaluation helps us measure model performance before and after training.
+# We'll run evaluation both BEFORE and AFTER GRPO training to measure improvement.
+#
+# Evaluation Metrics:
+# ------------------
+# QUANTITATIVE:
+# 1. Answer Accuracy: % of samples with exact correct numerical answer
+# 2. Answer Partial Accuracy: % of samples where answer is within ±10% of truth
+# 3. Format Accuracy: % of samples with correct <reasoning> and <answer> format
+#
+# QUALITATIVE:
+# - We can also sample and manually inspect specific model outputs
+# - Useful for understanding HOW the model's reasoning improves
+#
+# The evaluation functions:
+# - generate_responses(): Uses vLLM to generate model outputs
+# - score_responses(): Applies reward functions to score outputs
+# - evaluate(): Runs full evaluation pipeline and returns metrics
+
+
+def generate_responses(
+ prompts,
+ rl_cluster,
+ num_passes=1,
+ temperature=0.7,
+ top_k=50,
+ top_p=0.95,
+):
+ """
+ Generate responses for a batch of prompts across multiple passes.
+
+ Args:
+ prompts: List of prompts to generate responses for
+ rl_cluster: Model cluster for generation
+ num_passes: Number of generation passes
+ temperature: Sampling temperature
+ top_k: Top-k sampling parameter
+ top_p: Top-p sampling parameter
+
+ Returns:
+ List of lists containing responses for each prompt across passes
+ """
+ multiple_call_responses = [[] for _ in range(len(prompts))]
+
+ for p in range(num_passes):
+ responses = rl_cluster.rollout.generate(
+ prompts,
+ rollout_config=RolloutConfig(
+ max_tokens_to_generate=TOTAL_GENERATION_STEPS,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ ),
+ )
+ responses = responses.text
+
+ if DEBUG:
+ print(f"Pass {p+1}/{num_passes}, responses: {responses}")
+
+ for idx, response in enumerate(responses):
+ multiple_call_responses[idx].append(response)
+
+ return multiple_call_responses
+
+
+def score_responses(question, responses, answer):
+ """
+ Score a set of responses for a single question.
+
+ Args:
+ question: The evaluation question
+ responses: List of generated responses for this question
+ answer: The correct answer
+
+ Returns:
+ Tuple of (is_correct, is_partially_correct, has_correct_format)
+ """
+ if DEBUG:
+ print("========================================")
+ print(f"Evaluation Question: {question}")
+ print(f"Evaluation Answer: {answer}")
+ print(f"Evaluation Responses: {responses}")
+ print("========================================")
+
+ is_correct = False
+ is_partially_correct = False
+ has_correct_format = False
+
+ for response in responses:
+ # Extract numerical response
+ extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"
+
+ if DEBUG:
+ print(f"Evaluation extracted_response: {extracted_response}")
+
+ # Check exact correctness
+ try:
+ if float(extracted_response.strip()) == float(answer.strip()):
+ is_correct = True
+
+ # Check partial correctness (within 10%)
+ ratio = float(extracted_response.strip()) / float(answer.strip())
+ if 0.9 <= ratio <= 1.1:
+ is_partially_correct = True
+ except (ValueError, TypeError, ZeroDivisionError) as e:
+ if DEBUG:
+ print(f"Evaluation Exception: {e}")
+ print("SKIPPED")
+
+ # Check format correctness
+ if match_format.search(response) is not None:
+ has_correct_format = True
+
+ # Early exit if all criteria are met
+ if is_correct and is_partially_correct and has_correct_format:
+ break
+
+ return is_correct, is_partially_correct, has_correct_format
+
+
+def evaluate(
+ dataset,
+ rl_cluster,
+ temperature=0.7,
+ top_k=50,
+ top_p=0.95,
+ num_passes=1,
+ corr_lst=False,
+ make_lst=False,
+):
+ """
+ Computes accuracy and percentage of outputs matching the format.
+
+ Args:
+ dataset: The evaluation dataset
+ rl_cluster: Model cluster for generation
+ temperature: Sampling temperature
+ top_k: Top-k sampling parameter
+ top_p: Top-p sampling parameter
+ num_passes: Number of generation passes
+ corr_lst: If True, only include correct responses in the list
+ make_lst: If True, return a list of (question, answer, responses)
+
+ Returns:
+ Tuple of statistics and optionally the response list
+ """
+ response_lst = []
+ corr = 0
+ partially_corr = 0
+ corr_format = 0
+ total = 0
+
+ for batch in tqdm(dataset):
+ answers = batch["answer"]
+ questions = batch["question"]
+ prompts = batch["prompts"]
+
+ # Generate responses for all prompts in the batch
+ multiple_call_responses = generate_responses(
+ prompts=prompts,
+ rl_cluster=rl_cluster,
+ num_passes=num_passes,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ # Score each question-answer pair
+ for question, responses, answer in zip(questions, multiple_call_responses, answers):
+ is_correct, is_partially_correct, has_correct_format = score_responses(
+ question=question,
+ responses=responses,
+ answer=answer,
+ )
+
+ # Update counters
+ if is_correct:
+ corr += 1
+ if corr_lst and make_lst:
+ response_lst.append((question, answer, responses))
+ else:
+ if not corr_lst and make_lst:
+ response_lst.append((question, answer, responses))
+
+ if is_partially_correct:
+ partially_corr += 1
+
+ if has_correct_format:
+ corr_format += 1
+
+ total += 1
+
+ # Print progress every 10 items
+ if total % 10 == 0:
+ print(
+ f"===> {corr=}, {total=}, {corr / total * 100=}, "
+ f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
+ )
+
+ # Prepare return values
+ to_return = (
+ corr,
+ total,
+ corr / total * 100,
+ partially_corr / total * 100,
+ corr_format / total * 100,
+ )
+
+ if make_lst:
+ return to_return, response_lst
+ return to_return
+
+
+# ==============================================================================
+# STEP 8: MAIN TRAINING PIPELINE
+# ==============================================================================
+#
+# The main() function orchestrates the entire GRPO training workflow:
+#
+# 1. Setup Infrastructure:
+# - Configure checkpointing (save model every N steps)
+# - Setup metrics logging (TensorBoard)
+# - Configure profiling
+#
+# 2. Create Optimizer:
+# - AdamW optimizer with warmup + cosine decay learning rate schedule
+# - Gradient clipping for stability
+#
+# 3. Setup RL Cluster:
+# - Combines policy model, reference model, and tokenizer
+# - Configures vLLM for rollout (response generation)
+# - Sets up mesh for distributed training
+#
+# 4. Initialize GRPO Learner:
+# - Combines RL cluster with reward functions
+# - Configures GRPO-specific hyperparameters (beta, epsilon, etc.)
+#
+# 5. Pre-Training Evaluation:
+# - Measure baseline performance before training
+#
+# 6. Training Loop:
+# - GRPO trainer runs for MAX_STEPS
+# - Each step: generate responses → compute rewards → update policy
+#
+# 7. Post-Training Evaluation:
+# - Measure final performance to see improvement
+#
+# Let's begin!
+
+def main():
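+  """Runs the end-to-end GRPO training and evaluation pipeline."""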
+ # --- 1. Setup Infrastructure ---
+
+ # Checkpoint manager: saves model weights periodically
+ checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP)
+
+ # Metrics logger: tracks training metrics for TensorBoard
+ metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=LOG_DIR, flush_every_n_steps=20)
+
+ # Print TensorBoard command for monitoring
+ print(f"TensorBoard logs directory: {LOG_DIR}")
+ print(f"tensorboard --logdir {LOG_DIR} --port=8086")
+
+ # --- 2. Create Optimizer with Learning Rate Schedule ---
+
+ # AdamW optimizer with warmup + cosine decay schedule
+ # LR starts at 0, increases to LEARNING_RATE over WARMUP_STEPS,
+ # then decreases to 0 following a cosine curve
+ optimizer = optax.adamw(
+ learning_rate=optax.schedules.warmup_cosine_decay_schedule(
+ init_value=0.0, # Start LR at zero
+ peak_value=LEARNING_RATE, # Peak LR after warmup
+ warmup_steps=WARMUP_STEPS, # Linear warmup period
+ decay_steps=MAX_STEPS, # Total steps for cosine decay
+ end_value=0.0, # End LR at zero
+ ),
+ b1=B1, # Adam beta1 (momentum)
+ b2=B2, # Adam beta2 (variance)
+ weight_decay=WEIGHT_DECAY, # L2 regularization
+ )
+
+ # Add gradient clipping for training stability
+ # Prevents exploding gradients and helps control KL divergence
+ if MAX_GRAD_NORM is not None:
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
+ optimizer,
+ )
+
+ # --- 3. Setup RL Cluster Configuration ---
+
+ # The RL Cluster manages three roles:
+ # - ACTOR (policy model): generates responses and gets trained
+ # - REFERENCE: frozen model for KL divergence computation
+ # - ROLLOUT: vLLM engine for efficient generation during training
+ cluster_config = rl_cluster_lib.ClusterConfig(
+ role_to_mesh={
+ rl_cluster_lib.Role.ACTOR: mesh, # Policy model mesh
+ rl_cluster_lib.Role.REFERENCE: mesh, # Reference model mesh
+ rl_cluster_lib.Role.ROLLOUT: mesh, # vLLM rollout mesh
+ },
+ rollout_engine="vllm", # Use vLLM for fast generation
+ offload_to_cpu=False, # Keep everything on TPU/GPU
+ training_config=rl_cluster_lib.RLTrainingConfig(
+ actor_optimizer=optimizer, # Optimizer for policy model
+ eval_every_n_steps=EVAL_EVERY_N_STEPS, # Validation frequency
+ max_steps=MAX_STEPS, # Total training steps
+ gradient_accumulation_steps=None, # No gradient accumulation
+ metrics_logging_options=metrics_logging_options, # TensorBoard logging
+ checkpoint_root_directory=CKPT_DIR, # Checkpoint save dir
+ checkpointing_options=checkpointing_options, # Checkpoint frequency
+ ),
+ rollout_config=base_rollout.RolloutConfig(
+ max_tokens_to_generate=TOTAL_GENERATION_STEPS, # Max response length
+ max_prompt_length=MAX_PROMPT_LENGTH, # Max input length
+ kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256, # Cache size
+ temperature=TEMPERATURE, # Sampling temperature
+ top_p=TOP_P, # Nucleus sampling
+ top_k=TOP_K, # Top-k sampling
+ ),
+ rollout_vllm_model_version="meta-llama/Llama-3.1-8B", # HuggingFace model ID for vLLM
+ rollout_vllm_hbm_utilization=0.2, # vLLM memory usage (20%)
+ rollout_vllm_tpu_backend_type="jax", # Use JAX backend for vLLM
+ )
+
+ # --- 4. Initialize GRPO Configuration ---
+
+ # GRPO-specific hyperparameters
+ grpo_config = GrpoConfig(
+ num_generations=NUM_GENERATIONS, # Responses per prompt (group size)
+ num_iterations=NUM_ITERATIONS, # Optimization iterations per batch
+ beta=BETA, # KL divergence penalty coefficient
+ epsilon=EPSILON, # PPO-style clipping parameter
+ )
+
+ # Create RL Cluster: combines models and configuration
+ rl_cluster = rl_cluster_lib.RLCluster(
+ actor=llama3_8b_policy, # Policy model (trainable)
+ reference=llama3_8b, # Reference model (frozen)
+ tokenizer=model_tokenizer, # Tokenizer for both models
+ cluster_config=cluster_config, # Cluster configuration
+ )
+
+ # Create GRPO Trainer: combines RL cluster with reward functions
+ grpo_trainer = GrpoLearner(
+ rl_cluster=rl_cluster,
+ reward_fns=[ # List of reward functions to use
+ match_format_exactly, # Reward 1: Exact format match
+ match_format_approximately, # Reward 2: Approximate format match
+ check_answer, # Reward 3: Answer correctness
+ check_numbers, # Reward 4: Number extraction fallback
+ ],
+ grpo_config=grpo_config, # GRPO hyperparameters
+ )
+
+ # Debug: Test vLLM generation (optional sanity check)
+ if DEBUG:
+ print("Testing vLLM generation...")
+ output = rl_cluster.rollout.generate(
+ ["The capital of France is"],
+ rollout_config=RolloutConfig(max_tokens_to_generate=64, temperature=0.1),
+ )
+ print(f"vLLM test output: {output}")
+
+ # --- 5. Pre-Training Evaluation ---
+ #
+ # Evaluate model BEFORE training to establish a baseline
+ # This helps us measure how much GRPO improves the model
+
+ print("\n" + "="*80)
+ print("EVALUATING MODEL BEFORE GRPO TRAINING")
+ print("="*80 + "\n")
+
+ # pylint: disable=unbalanced-tuple-unpacking
+ (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
+ test_dataset,
+ rl_cluster,
+ **GENERATION_CONFIGS["greedy"], # Use greedy decoding for deterministic eval
+ )
+ print(f"\nPre-Training Results:")
+ print(f" Correct: {corr}/{total}")
+ print(f" Answer Accuracy: {accuracy:.2f}%")
+ print(f" Partial Accuracy: {partial_accuracy:.2f}%")
+ print(f" Format Accuracy: {format_accuracy:.2f}%\n")
+
+ # --- 6. Training Loop ---
+ #
+ # This is where the magic happens! GRPO training loop:
+ # For each batch:
+ # 1. Generate multiple responses per prompt (using vLLM)
+ # 2. Score responses using reward functions
+ # 3. Compute advantages (how much better than group average)
+ # 4. Update policy to increase probability of high-reward responses
+ # 5. Apply KL penalty to prevent drift from reference model
+
+ print("="*80)
+ print("STARTING GRPO TRAINING")
+ print("="*80 + "\n")
+
+ # Start JAX profiler for performance analysis
+ # Uncomment to enable profiling (note: generates large trace files for long training runs)
+ # jax.profiler.start_trace(PROFILE_DIR)
+
+ # Run training with proper mesh and axis rules for distributed training
+ with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules):
+    grpo_trainer.train(train_dataset)
+
+ # Stop profiler (uncomment if profiling is enabled above)
+ # jax.profiler.stop_trace()
+
+ print("\n" + "="*80)
+ print("TRAINING COMPLETE")
+ print("="*80 + "\n")
+
+ # Check memory usage after training
+ print("HBM usage after training:")
+ show_hbm_usage()
+
+ # --- 7. Post-Training Evaluation ---
+ #
+ # Evaluate model AFTER training to measure improvement
+
+ print("\n" + "="*80)
+ print("EVALUATING MODEL AFTER GRPO TRAINING")
+ print("="*80 + "\n")
+
+ # pylint: disable=unbalanced-tuple-unpacking
+ (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
+ test_dataset,
+ rl_cluster,
+ **GENERATION_CONFIGS["greedy"],
+ )
+ print(f"\nPost-Training Results:")
+ print(f" Correct: {corr}/{total}")
+ print(f" Answer Accuracy: {accuracy:.2f}%")
+ print(f" Partial Accuracy: {partial_accuracy:.2f}%")
+ print(f" Format Accuracy: {format_accuracy:.2f}%\n")
+
+ print("="*80)
+ print("GRPO TUTORIAL COMPLETE!")
+ print("="*80)
+
+
+if __name__ == "__main__":
+ main()