From 033fa6037bcf5f58585c6a59226861461c55c89b Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 12:31:55 -0700 Subject: [PATCH 01/24] Add LanguageReward for training models to think in target language This commit introduces a new reward function that encourages models to think in a specific target language within their tags. Key changes: - Add LanguageReward class to src/forge/data/rewards.py - Uses langid for language detection - Configurable target language (ISO 639-1 codes) - Returns full_reward for language match, no_match_reward otherwise - Raises helpful error if langid not installed - Add comprehensive unit tests in tests/unit_tests/rl/test_language_reward.py - Tests for multiple languages (English, Japanese, Chinese, Spanish, etc.) - Tests for edge cases and error handling - All 28 tests passing - Create sandbox/grpo_language/ app for experimentation - Extends apps/grpo/ with LanguageReward - Hardcoded to Japanese (ja) as default target language - Includes README with usage instructions - Config file for Qwen3-1.7B model Implementation details: - Extracts text from tags for analysis - Concatenates multiple thinking blocks for language detection - Compatible with existing MathReward and ThinkingReward - Does not add langid to requirements.txt (optional dependency) Usage: python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml Note: Requires 'pip install langid' before use --- sandbox/grpo_language/README.md | 81 ++++ sandbox/grpo_language/main.py | 498 ++++++++++++++++++++ sandbox/grpo_language/qwen3_1_7b.yaml | 151 ++++++ src/forge/data/rewards.py | 79 ++++ tests/unit_tests/rl/test_language_reward.py | 287 +++++++++++ 5 files changed, 1096 insertions(+) create mode 100644 sandbox/grpo_language/README.md create mode 100644 sandbox/grpo_language/main.py create mode 100644 sandbox/grpo_language/qwen3_1_7b.yaml create mode 100644 tests/unit_tests/rl/test_language_reward.py diff --git a/sandbox/grpo_language/README.md b/sandbox/grpo_language/README.md new file mode 100644 index 000000000..46c573cb8 --- /dev/null +++ b/sandbox/grpo_language/README.md @@ -0,0 +1,81 @@ +# GRPO with Language Reward + +This sandbox app demonstrates using GRPO training with a language reward that encourages the model to think in a specific target language. + +## Overview + +This app extends the standard GRPO training (from `apps/grpo/`) by adding a `LanguageReward` that evaluates whether the model's thinking (text within `` tags) is in the target language. + +## Key Features + +- **Multi-objective training**: Combines three rewards: + - `MathReward`: Evaluates correctness of math answers + - `ThinkingReward`: Encourages use of thinking tags + - `LanguageReward`: Rewards thinking in target language (Japanese by default) + +- **Language detection**: Uses `langid` to detect the language of thinking blocks + +- **Configurable target language**: While this app defaults to Japanese (`ja`), the `LanguageReward` can be configured for any ISO 639-1 language code + +## Requirements + +Before running this app, install the required language detection library: + +```bash +pip install langid +``` + +## Usage + +```bash +python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml +``` + +## How It Works + +1. The model receives a math problem and is instructed to use `` tags for reasoning +2. During training, the model generates responses with thinking blocks +3. Three rewards are computed: + - Math correctness (did it get the right answer?) + - Thinking usage (did it use thinking tags properly?) + - Language usage (did it think in Japanese?) +4. The model is trained to maximize all three rewards + +## Configuration + +The target language is hardcoded as Japanese in `main.py` (line 321): + +```python +LanguageReward(target_language="ja") +``` + +To use a different language, modify this line with the appropriate ISO 639-1 code: +- English: `"en"` +- Chinese: `"zh"` +- Spanish: `"es"` +- French: `"fr"` +- etc. + +## Expected Behavior + +Over the course of training, the model should learn to: +1. Solve math problems correctly +2. Use `` tags for its reasoning +3. Write its thinking in Japanese (or the configured target language) + +## Metrics + +The following metrics are logged to W&B: +- `reward/evaluate_response/avg_LanguageReward_reward`: Average language reward +- `reward/evaluate_response/avg_MathReward_reward`: Average math reward +- `reward/evaluate_response/avg_ThinkingReward_reward`: Average thinking reward +- `reward/evaluate_response/avg_total_reward`: Average of all rewards + +## Differences from Standard GRPO + +This is a modified version of `apps/grpo/main.py` with: +1. Added import: `from forge.data.rewards import LanguageReward` +2. Modified reward functions list to include `LanguageReward(target_language="ja")` +3. Updated config to use different W&B group name + +All other training logic remains the same. diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py new file mode 100644 index 000000000..2972eea09 --- /dev/null +++ b/sandbox/grpo_language/main.py @@ -0,0 +1,498 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Usage: python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml + +import asyncio +import time +import uuid +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.nn.functional as F +import torchstore as ts +from datasets import load_dataset +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) +from forge.actors.generator import Generator +from forge.actors.reference_model import ReferenceModel +from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import RLTrainer +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data.rewards import LanguageReward, MathReward, ThinkingReward +from forge.data_models.completion import Completion +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + +from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse +from forge.util.ops import compute_logprobs +from monarch.actor import endpoint +from omegaconf import DictConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + + +@dataclass +class Episode: + episode_id: str + pad_id: int + request_len: int + response_len: int + target: Any | None = None + # Processed data + completion: Completion | None = None + ref_logprobs: torch.Tensor | None = None + reward: float | None = None + advantage: float | None = None + + @property + def policy_version(self) -> int | None: + return self.completion.generator_version + + @property + def request_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.token_ids.to(torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor + + +# Represents the group (G) of episodes in GRPO +Group = list[Episode] + +# Represents the Policy Model to collect data from +Policy = Generator + + +def collate( + batches: list[Group], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Collates a list of batches into a single batch of inputs and targets. + Each batch is a list of episodes, and each episode is a dict of tensors. + """ + inputs = [] + targets = [] + for batch in batches: + request = [e.request_tensor for e in batch] + request = torch.stack(request) # [b x s] + + response = [e.response_tensor for e in batch] + response = torch.stack(response) # [b x s] + + ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s] + + advantages = [e.advantage for e in batch] + advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1] + + pad_id = batch[0].pad_id + mask = response != pad_id + + input = {"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss` +def simple_grpo_loss( + logits: torch.Tensor, + response: torch.Tensor, + ref_logprobs: torch.Tensor, + advantages: torch.Tensor, + padding_mask: torch.Tensor, + beta: float = 0.1, +) -> torch.Tensor: + logprobs: torch.Tensor = compute_logprobs(logits, response) + kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - beta * kl) + loss = ( + ((per_token_loss * padding_mask).sum(dim=1)) + / (padding_mask.sum(dim=1).clamp(min=1.0)) + ).mean() + return loss + + +@dataclass +class RewardActor(ForgeActor): + + reward_functions: list[Callable] + + @endpoint + async def evaluate_response(self, prompt: str, response: str, target: str) -> float: + total_rewards = 0.0 + for reward_fn in self.reward_functions: + reward = reward_fn(prompt, response, target) + total_rewards += reward + + # Get a name for the reward function (works for classes, functions, lambdas) + reward_fn_name = getattr( + reward_fn, "__name__", reward_fn.__class__.__name__ + ) + # per function reward + record_metric( + f"reward/evaluate_response/sum_{reward_fn_name}_reward", + reward, + Reduce.SUM, + ) + record_metric( + f"reward/evaluate_response/avg_{reward_fn_name}_reward", + reward, + Reduce.MEAN, + ) + record_metric( + f"reward/evaluate_response/std_{reward_fn_name}_reward", + reward, + Reduce.STD, + ) + + record_metric( + "reward/evaluate_response/avg_total_reward", + reward, + Reduce.MEAN, + ) + + record_metric( + f"reward/evaluate_response/count_{reward_fn_name}_calls", + 1, + Reduce.SUM, + ) + + avg_reward = total_rewards / len(self.reward_functions) + return avg_reward + + +@dataclass +class ComputeAdvantages(ForgeActor): + @endpoint + async def compute(self, group: Group) -> list[float]: + # TODO: add batch processing + rewards = torch.tensor([[e.reward for e in group]]) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() + + +@dataclass +class DatasetActor(ForgeActor): + """Actor wrapper for HuggingFace dataset to provide async interface.""" + + path: str = "openai/gsm8k" + revision: str = "main" + data_split: str = "train" + streaming: bool = True + model: str = "Qwen/Qwen3-1.7B" + + @endpoint + def setup(self): + self._tokenizer = get_tokenizer(self.model) + + def gsm8k_transform(sample): + system_prompt = """ + Put all your scratchpad work between and tags. + Your final answer should be between and tags otherwise it will not be scored. + """ + request: str = sample["question"] + as_chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + formatted_request = self._tokenizer.apply_chat_template( + as_chat, + tokenize=False, + add_generation_prompt=True, + ) + target: str = sample["answer"] + formatted_target = target.split("#### ")[1] + return {"request": formatted_request, "target": formatted_target} + + ds = load_dataset( + self.path, self.revision, split=self.data_split, streaming=self.streaming + ) + ds = ds.map(gsm8k_transform) + ds = ds.shuffle() + self._iterator = iter(ds) + + @endpoint + async def sample(self) -> dict[str, str] | None: + try: + sample = next(self._iterator) + + record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + record_metric( + "dataset/sample/avg_sample_len", + len(sample["request"]), + Reduce.MEAN, + ) + + return sample + except StopIteration: + return None + + @endpoint + async def pad_token(self): + return self._tokenizer.pad_token_id + + +async def drop_weights(version: int): + print(f"Dropping weights @ version {version}") + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + # TODO: once we have something like `get_meta()` in torchstore, we can just + # query the type of the object instead of relying on keys. + dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") + + +async def main(cfg: DictConfig): + """Main GRPO training loop with rollout and training processes.""" + group_size = cfg.group_size + max_req_tokens = cfg.max_req_tokens + max_res_tokens = cfg.max_res_tokens + + # ---- Global setups ---- # + provisioner = None + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + + metric_logging_cfg = cfg.get("metric_logging", {}) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + + # ---- Setup services ---- # + + ( + dataloader, + policy, + trainer, + replay_buffer, + compute_advantages, + ref_model, + reward_actor, + ) = await asyncio.gather( + DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), + Policy.options(**cfg.services.policy).as_service(**cfg.policy), + RLTrainer.options(**cfg.actors.trainer).as_actor( + **cfg.trainer, loss=simple_grpo_loss + ), + ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor( + **cfg.replay_buffer, collate=collate + ), + ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(), + ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model), + RewardActor.options(**cfg.services.reward_actor).as_service( + reward_functions=[ + MathReward(), + ThinkingReward(), + LanguageReward(target_language="ja"), # Japanese language reward + ] + ), + ) + + # Set max_steps to the configured value, or -1 if not specified or Null + max_steps = cfg.trainer.training.steps or -1 + + print("All services initialized successfully!") + shutdown_event = asyncio.Event() + # Here we spawn a torchstore storage volume per trainer process. + # We initialize after service initialization because torchstore currently + # requires access to the underlying proc meshes in the local rank strategy. + # We should be able to hide this in the future. + # TODO: support multiple host meshes + trainer_num_procs = cfg.actors.trainer["procs"] + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), + strategy=ts.LocalRankStrategy(), + ) + print("Torchstore successfully initialized with local rank strategy") + + # ---- Core RL loops ---- # + async def continuous_rollouts(): + rollout_count = 0 + pad_id = await dataloader.pad_token.call_one() + while not shutdown_event.is_set(): + t = Tracer("main_perf/continuous_rollouts") + t.start() + sample = await dataloader.sample.call_one() + if sample is None: + print("Dataloader is empty, exiting continuous rollout") + return + + t.step("data_loading") + + prompt, target = sample["request"], sample["target"] + responses: list[Completion] = await policy.generate.route(prompt) + t.step("policy_generation") + + # Construct episodes and calculate rewards + episodes = [] + input_ids = torch.ones( + (group_size, max_req_tokens + max_res_tokens), + dtype=torch.long, + ) + for i, response in enumerate(responses): + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + target=target, + completion=response, + ) + episode.reward = await reward_actor.evaluate_response.route( + prompt=prompt, response=response.text, target=target + ) + episodes.append(episode) + + # Build input_ids for reference logprobs + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor + + t.step("reward_evaluation") + + ref_logprobs = await ref_model.forward.route( + input_ids, max_req_tokens, return_logprobs=True + ) + t.step("reference_model_calculate_logprobs") + + for i, episode in enumerate(episodes): + episode.ref_logprobs = ref_logprobs[i] + del ref_logprobs, input_ids + + advantages = await compute_advantages.compute.call_one(episodes) + for episode, advantage in zip(episodes, advantages): + episode.advantage = advantage + await replay_buffer.add.call_one(episode) + + rollout_count += 1 + record_metric( + "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM + ) + t.stop() + + async def continuous_training(): + training_step = 0 + restart_tracer = True # Flag to control when to restart tracer + + while max_steps == -1 or training_step < max_steps: + # Restart tracer when needed (initial start or after completing a training step) + # Otherwise, we cannot measure time waiting for buffer + if restart_tracer: + t = Tracer("main_perf/continuous_training") + t.start() + restart_tracer = False + + batch = await replay_buffer.sample.call_one( + curr_policy_version=training_step + ) + if batch is None: + await asyncio.sleep(0.1) + else: + t.step("waiting_for_buffer") + + inputs, targets = batch + await trainer.train_step.call(inputs, targets) + training_step += 1 + t.step("train_step") + + await trainer.push_weights.call(training_step) + t.step("push_weights") + + await policy.update_weights.fanout(training_step) + t.step("update_weights") + + if training_step >= 2: + await drop_weights(training_step - 1) + t.step("drop_weights") + + t.stop() + restart_tracer = True + + # Flush metrics every training step to WandB + await mlogger.flush.call_one(training_step) + + print( + f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." + ) + + num_rollout_threads = cfg.get("rollout_threads", 1) + num_training_threads = cfg.get("training_threads", 1) + print( + f"Starting GRPO with {num_rollout_threads} rollout threads, {num_training_threads} training threads" + ) + rollout_tasks = [ + asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) + ] + training_task = asyncio.create_task(continuous_training()) + + try: + await training_task + except KeyboardInterrupt: + print("Training interrupted by user") + finally: + print("Shutting down... (this may take a few seconds)") + shutdown_event.set() + + try: + # Give rollouts up to 5s to finish naturally + await asyncio.wait_for( + asyncio.gather(*rollout_tasks, return_exceptions=True), + timeout=5, + ) + except asyncio.TimeoutError: + print("Timeout waiting for rollouts; forcing cancellation...") + for t in rollout_tasks: + t.cancel() + await asyncio.gather(*rollout_tasks, return_exceptions=True) + + training_task.cancel() + + await shutdown() + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() # @parse grabs the cfg from CLI diff --git a/sandbox/grpo_language/qwen3_1_7b.yaml b/sandbox/grpo_language/qwen3_1_7b.yaml new file mode 100644 index 000000000..e6715a219 --- /dev/null +++ b/sandbox/grpo_language/qwen3_1_7b.yaml @@ -0,0 +1,151 @@ +# Grouped Relative Policy Optimization (GRPO) with Language Reward +# >>> python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml + +# Global configuration +group_size: 8 +local_batch_size: 16 # per-device batch size +max_req_tokens: 1024 +max_res_tokens: 1024 +model: "Qwen/Qwen3-1.7B" +off_by_n: 1 # Off by one by default + +# Main loop configuration +rollout_threads: 1 # Recommended to set equal to policy.num_replicas + + +# Observability configuration +metric_logging: + wandb: + project: grpo-training + group: grpo_language_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + console: + logging_mode: global_reduce + +# Dataset configuration +dataset: + path: "openai/gsm8k" + revision: "main" + data_split: "train" + streaming: true + model: ${model} + +# Policy configuration +policy: + engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams + n: ${group_size} + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${local_batch_size} + seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens + max_norm: 1.0 + steps: 1000000 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration +replay_buffer: + batch_size: ${local_batch_size} + max_policy_age: ${off_by_n} + dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree + +# Reference model configuration +ref_model: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + +# All resource allocations +services: + policy: + procs: ${policy.engine_args.tensor_parallel_size} + num_replicas: 1 + mesh_name: policy + with_gpus: true + ref_model: + procs: 1 + num_replicas: 1 + mesh_name: ref_model + with_gpus: true + reward_actor: + procs: 1 + num_replicas: 1 + mesh_name: reward_actor + with_gpus: false + +actors: + dataset: + procs: 1 + with_gpus: false + mesh_name: dataset + trainer: + procs: 1 + with_gpus: true + mesh_name: trainer + replay_buffer: + procs: 1 + with_gpus: false + mesh_name: replay_buffer + compute_advantages: + procs: 1 + with_gpus: false + mesh_name: compute_advantages diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 23a0002df..e937d5772 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -80,3 +80,82 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo elif has_attempt: return self.partial_reward return 0.0 + + +class LanguageReward: + """Reward class for evaluating the language used in tags. + + This reward uses langid to detect the language of text within thinking blocks + and rewards responses that use the target language. + + Args: + target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es') + full_reward: Reward when detected language matches target + no_match_reward: Reward when detected language doesn't match target + + Note: Requires langid to be installed. Install with: pip install langid + """ + + def __init__( + self, + target_language: str = "en", + full_reward: float = 1.0, + no_match_reward: float = 0.0, + ): + self.target_language = target_language + self.full_reward = full_reward + self.no_match_reward = no_match_reward + self._THINK_BLOCK_RE = re.compile( + r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL + ) + + # Lazy import langid with helpful error message + try: + import langid + + self._langid = langid + except ImportError: + raise ImportError( + "langid is required for LanguageReward but is not installed. " + "Please install it with: pip install langid" + ) from None + + def __call__(self, prompt: str, response: str, target: str | None = None) -> float: + """Compute language reward based on thinking block content. + + Args: + prompt: The input prompt (unused but kept for signature consistency) + response: The model response containing tags + target: Optional target string (unused but kept for signature consistency) + + Returns: + full_reward if detected language matches target_language and format is correct, + no_match_reward otherwise (including when format is wrong or no thinking block) + """ + if not response: + return self.no_match_reward + + # Extract all thinking blocks + matches = self._THINK_BLOCK_RE.findall(response) + + # Return 0 reward if format is wrong (0 or multiple thinking blocks) + if len(matches) != 1: + return self.no_match_reward + + # Get the single thinking block content + thinking_content = matches[0] + + # Remove extra whitespace + thinking_content = re.sub(r"\s+", " ", thinking_content).strip() + + if not thinking_content: + return self.no_match_reward + + # Detect language using langid + detected_lang, confidence = self._langid.classify(thinking_content) + + # Return full reward if language matches target + if detected_lang == self.target_language: + return self.full_reward + + return self.no_match_reward diff --git a/tests/unit_tests/rl/test_language_reward.py b/tests/unit_tests/rl/test_language_reward.py new file mode 100644 index 000000000..67379a9b9 --- /dev/null +++ b/tests/unit_tests/rl/test_language_reward.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import unittest +from unittest.mock import patch + + +class TestLanguageReward(unittest.TestCase): + def setUp(self): + """Set up test fixtures before each test method.""" + # Import after patching to avoid ImportError + from forge.data.rewards import LanguageReward + + self.LanguageReward = LanguageReward + self.reward_en = LanguageReward(target_language="en") + self.reward_ja = LanguageReward(target_language="ja") + self.custom_reward = LanguageReward( + target_language="ja", full_reward=0.9, no_match_reward=0.1 + ) + + def test_init_default_values(self): + """Test LanguageReward initialization with default values.""" + reward = self.LanguageReward() + self.assertEqual(reward.target_language, "en") + self.assertEqual(reward.full_reward, 1.0) + self.assertEqual(reward.no_match_reward, 0.0) + + def test_init_custom_values(self): + """Test LanguageReward initialization with custom values.""" + reward = self.LanguageReward( + target_language="ja", full_reward=0.9, no_match_reward=0.1 + ) + self.assertEqual(reward.target_language, "ja") + self.assertEqual(reward.full_reward, 0.9) + self.assertEqual(reward.no_match_reward, 0.1) + + def test_init_missing_langid(self): + """Test LanguageReward initialization without langid installed.""" + # Remove langid from modules if it exists + langid_module = sys.modules.get("langid") + if "langid" in sys.modules: + del sys.modules["langid"] + + with patch.dict("sys.modules", {"langid": None}): + with self.assertRaises(ImportError) as context: + # Re-import to trigger the ImportError + import importlib + + import forge.data.rewards + + importlib.reload(forge.data.rewards) + forge.data.rewards.LanguageReward() + + self.assertIn("langid is required", str(context.exception)) + self.assertIn("pip install langid", str(context.exception)) + + # Restore langid module if it existed + if langid_module is not None: + sys.modules["langid"] = langid_module + + def test_regex_pattern(self): + """Test that regex pattern is compiled correctly.""" + reward = self.LanguageReward() + self.assertIsNotNone(reward._THINK_BLOCK_RE) + + def test_call_with_english_thinking(self): + """Test __call__ with English text in thinking blocks.""" + response = "This is English reasoning about math problems." + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_with_japanese_thinking(self): + """Test __call__ with Japanese text in thinking blocks.""" + response = "これは日本語で考えています。数学の問題を解きます。" + result = self.reward_ja("prompt", response) + self.assertEqual(result, 1.0) + + # English reward should give no_match_reward for Japanese text + result = self.reward_en("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_with_chinese_thinking(self): + """Test __call__ with Chinese text in thinking blocks.""" + response = "这是中文思考。我们需要解决这个数学问题。" + reward_zh = self.LanguageReward(target_language="zh") + result = reward_zh("prompt", response) + # langid should detect this as Chinese (zh) + self.assertEqual(result, 1.0) + + def test_call_with_spanish_thinking(self): + """Test __call__ with Spanish text in thinking blocks.""" + response = "Este es un razonamiento en español sobre problemas matemáticos." + reward_es = self.LanguageReward(target_language="es") + result = reward_es("prompt", response) + # langid should detect this as Spanish (es) + self.assertEqual(result, 1.0) + + def test_call_language_mismatch(self): + """Test __call__ when detected language doesn't match target.""" + # Japanese reward with English text + response = "This is English reasoning." + result = self.reward_ja("prompt", response) + self.assertEqual(result, 0.0) + + # English reward with Japanese text + response = "これは日本語です。" + result = self.reward_en("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_with_no_thinking_tags(self): + """Test __call__ with response containing no thinking tags.""" + result = self.reward_en( + "prompt", "This is just a regular response without any thinking tags." + ) + self.assertEqual(result, 0.0) + + def test_call_with_empty_thinking_block(self): + """Test __call__ with empty thinking block.""" + result = self.reward_en("prompt", "") + self.assertEqual(result, 0.0) + + def test_call_with_whitespace_only_thinking_block(self): + """Test __call__ with whitespace-only thinking block.""" + result = self.reward_en("prompt", " \n \t ") + self.assertEqual(result, 0.0) + + def test_call_case_insensitive(self): + """Test __call__ is case insensitive for thinking tags.""" + response = "This is English reasoning." + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + response = "This is English reasoning." + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_with_whitespace_in_tags(self): + """Test __call__ with whitespace in thinking tags.""" + response = "< think >This is English reasoning." + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_multiple_thinking_blocks(self): + """Test __call__ with multiple thinking blocks (wrong format).""" + response = """ + First thought in English. + Some text in between. + Second thought also in English. + """ + result = self.reward_en("prompt", response) + # Multiple blocks = wrong format, should return 0 + self.assertEqual(result, 0.0) + + def test_call_multiple_thinking_blocks_mixed_languages(self): + """Test __call__ with multiple thinking blocks in different languages (wrong format).""" + response = """ + First thought in English with lots of content here. + これは短い日本語。 + """ + result = self.reward_en("prompt", response) + # Multiple blocks = wrong format, should return 0 + self.assertEqual(result, 0.0) + + def test_call_multiline_thinking_block(self): + """Test __call__ with multiline thinking blocks.""" + response = """ + This is a multiline + thinking block with + lots of English content + about solving problems + """ + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_empty_response(self): + """Test __call__ with empty response.""" + result = self.reward_en("prompt", "") + self.assertEqual(result, 0.0) + + def test_call_none_response(self): + """Test __call__ with None response.""" + result = self.reward_en("prompt", None) + self.assertEqual(result, 0.0) + + def test_call_with_target_parameter(self): + """Test __call__ with target parameter (should be ignored).""" + response = "This is English reasoning." + result = self.reward_en("prompt", response, target="some target") + self.assertEqual(result, 1.0) + + result = self.reward_en("prompt", "no tags", target="some target") + self.assertEqual(result, 0.0) + + def test_call_custom_reward_values(self): + """Test __call__ with custom reward values.""" + response_ja = "これは日本語です。" + response_en = "This is English." + response_none = "no thinking tags" + + # Test custom full reward + self.assertEqual(self.custom_reward("prompt", response_ja), 0.9) + # Test custom no_match reward + self.assertEqual(self.custom_reward("prompt", response_en), 0.1) + # Test no tags + self.assertEqual(self.custom_reward("prompt", response_none), 0.1) + + def test_call_zero_custom_values(self): + """Test __call__ with zero custom values.""" + zero_reward = self.LanguageReward( + target_language="en", full_reward=0.0, no_match_reward=0.0 + ) + result = zero_reward("prompt", "This is English.") + self.assertEqual(result, 0.0) + + def test_call_with_special_characters(self): + """Test __call__ with special characters in thinking blocks.""" + response = ( + "English with special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~" + ) + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_with_mixed_content_outside_tags(self): + """Test __call__ with mixed language content outside thinking tags.""" + # Content outside think tags should be ignored + response = """ + これは日本語のテキストです。 + But this is English reasoning inside the tags. + もっと日本語のテキスト。 + """ + result = self.reward_en("prompt", response) + # Should detect English from thinking block only + self.assertEqual(result, 1.0) + + def test_call_with_numbers_and_symbols(self): + """Test __call__ with thinking blocks containing mostly numbers.""" + response = "Calculate: 2 + 2 = 4, then 4 * 3 = 12" + result = self.reward_en("prompt", response) + # Should still detect as English due to words like "Calculate" and "then" + self.assertEqual(result, 1.0) + + def test_call_very_long_thinking_block(self): + """Test __call__ with very long thinking blocks.""" + long_content = "This is English content. " * 1000 + result = self.reward_en("prompt", f"{long_content}") + self.assertEqual(result, 1.0) + + def test_call_with_code_in_thinking(self): + """Test __call__ with code snippets in thinking blocks.""" + response = """ + Let me write some Python code to solve this: + def calculate(x): + return x * 2 + The function doubles the input value. + """ + result = self.reward_en("prompt", response) + # Should detect as English due to surrounding text + self.assertEqual(result, 1.0) + + def test_different_language_codes(self): + """Test __call__ with various ISO 639-1 language codes.""" + # Test a few common languages + languages = { + "fr": "Ceci est un texte en français avec beaucoup de contenu.", + "de": "Dies ist ein deutscher Text mit viel Inhalt.", + "it": "Questo è un testo italiano con molto contenuto.", + "pt": "Este é um texto em português com muito conteúdo.", + } + + for lang_code, text in languages.items(): + reward = self.LanguageReward(target_language=lang_code) + response = f"{text}" + result = reward("prompt", response) + # langid should detect these correctly + self.assertEqual( + result, + 1.0, + f"Failed to detect {lang_code} language: '{text[:50]}...'", + ) + + +if __name__ == "__main__": + unittest.main() From b12ed15a7828f84131e54e3e925bb4d2424f7d90 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 13:15:40 -0700 Subject: [PATCH 02/24] Update system prompt to instruct model to think in Japanese --- sandbox/grpo_language/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 2972eea09..75219c39c 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -215,7 +215,7 @@ def setup(self): def gsm8k_transform(sample): system_prompt = """ - Put all your scratchpad work between and tags. + Put all your scratchpad work between and tags. You must think in Japanese inside the tags. Your final answer should be between and tags otherwise it will not be scored. """ request: str = sample["question"] From b15f17134afe8af7294bf36d2b1c3420819c17b0 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 13:33:12 -0700 Subject: [PATCH 03/24] Add fallback reward for correct language without thinking blocks - Add fallback_reward parameter (default 0.2) - If no blocks found, check if response text is in target language - Reward structure: * full_reward (1.0): Single block + correct language * partial_reward (0.5): Multiple blocks + correct language * fallback_reward (0.2): No blocks + correct language in response text * no_match_reward (0.0): Wrong language - Update all tests to reflect new behavior (29 tests passing) --- src/forge/data/rewards.py | 48 +++++++++++---- tests/unit_tests/rl/test_language_reward.py | 65 +++++++++++++++------ 2 files changed, 86 insertions(+), 27 deletions(-) diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index e937d5772..64a6c856e 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -90,8 +90,10 @@ class LanguageReward: Args: target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es') - full_reward: Reward when detected language matches target - no_match_reward: Reward when detected language doesn't match target + full_reward: Reward when language matches and format is correct (single block) + partial_reward: Reward when language matches but format is wrong (multiple blocks) + fallback_reward: Reward when no valid blocks but response text is in target language + no_match_reward: Reward when language doesn't match Note: Requires langid to be installed. Install with: pip install langid """ @@ -100,10 +102,14 @@ def __init__( self, target_language: str = "en", full_reward: float = 1.0, + partial_reward: float = 0.5, + fallback_reward: float = 0.2, no_match_reward: float = 0.0, ): self.target_language = target_language self.full_reward = full_reward + self.partial_reward = partial_reward + self.fallback_reward = fallback_reward self.no_match_reward = no_match_reward self._THINK_BLOCK_RE = re.compile( r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL @@ -129,8 +135,10 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo target: Optional target string (unused but kept for signature consistency) Returns: - full_reward if detected language matches target_language and format is correct, - no_match_reward otherwise (including when format is wrong or no thinking block) + full_reward if language matches and exactly one thinking block is found, + partial_reward if language matches but multiple thinking blocks found, + fallback_reward if no valid blocks but response text is in target language, + no_match_reward otherwise (wrong language) """ if not response: return self.no_match_reward @@ -138,12 +146,27 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo # Extract all thinking blocks matches = self._THINK_BLOCK_RE.findall(response) - # Return 0 reward if format is wrong (0 or multiple thinking blocks) - if len(matches) != 1: + # If no thinking blocks found, check if response text is in target language + if len(matches) == 0: + # Remove any partial tags that might exist + response_text = re.sub( + r"<\s*/?\s*think\s*>", "", response, flags=re.IGNORECASE + ).strip() + + if not response_text: + return self.no_match_reward + + # Detect language of general response + detected_lang, confidence = self._langid.classify(response_text) + + # Give fallback reward if response is in target language + if detected_lang == self.target_language: + return self.fallback_reward + return self.no_match_reward - # Get the single thinking block content - thinking_content = matches[0] + # Concatenate all thinking blocks for language detection + thinking_content = " ".join(matches) # Remove extra whitespace thinking_content = re.sub(r"\s+", " ", thinking_content).strip() @@ -154,8 +177,13 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo # Detect language using langid detected_lang, confidence = self._langid.classify(thinking_content) - # Return full reward if language matches target + # Check if language matches target if detected_lang == self.target_language: - return self.full_reward + # Full reward for correct format (single block) + if len(matches) == 1: + return self.full_reward + # Partial reward for wrong format (multiple blocks) but correct language + else: + return self.partial_reward return self.no_match_reward diff --git a/tests/unit_tests/rl/test_language_reward.py b/tests/unit_tests/rl/test_language_reward.py index 67379a9b9..b88846fcf 100644 --- a/tests/unit_tests/rl/test_language_reward.py +++ b/tests/unit_tests/rl/test_language_reward.py @@ -19,7 +19,11 @@ def setUp(self): self.reward_en = LanguageReward(target_language="en") self.reward_ja = LanguageReward(target_language="ja") self.custom_reward = LanguageReward( - target_language="ja", full_reward=0.9, no_match_reward=0.1 + target_language="ja", + full_reward=0.9, + partial_reward=0.6, + fallback_reward=0.3, + no_match_reward=0.1, ) def test_init_default_values(self): @@ -27,15 +31,23 @@ def test_init_default_values(self): reward = self.LanguageReward() self.assertEqual(reward.target_language, "en") self.assertEqual(reward.full_reward, 1.0) + self.assertEqual(reward.partial_reward, 0.5) + self.assertEqual(reward.fallback_reward, 0.2) self.assertEqual(reward.no_match_reward, 0.0) def test_init_custom_values(self): """Test LanguageReward initialization with custom values.""" reward = self.LanguageReward( - target_language="ja", full_reward=0.9, no_match_reward=0.1 + target_language="ja", + full_reward=0.9, + partial_reward=0.6, + fallback_reward=0.3, + no_match_reward=0.1, ) self.assertEqual(reward.target_language, "ja") self.assertEqual(reward.full_reward, 0.9) + self.assertEqual(reward.partial_reward, 0.6) + self.assertEqual(reward.fallback_reward, 0.3) self.assertEqual(reward.no_match_reward, 0.1) def test_init_missing_langid(self): @@ -112,10 +124,17 @@ def test_call_language_mismatch(self): self.assertEqual(result, 0.0) def test_call_with_no_thinking_tags(self): - """Test __call__ with response containing no thinking tags.""" + """Test __call__ with response containing no thinking tags but correct language.""" result = self.reward_en( "prompt", "This is just a regular response without any thinking tags." ) + # No thinking blocks but response is in English, should get fallback reward + self.assertEqual(result, 0.2) + + def test_call_with_no_thinking_tags_wrong_language(self): + """Test __call__ with response containing no thinking tags and wrong language.""" + result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。") + # No thinking blocks and wrong language, should get no_match_reward self.assertEqual(result, 0.0) def test_call_with_empty_thinking_block(self): @@ -145,15 +164,15 @@ def test_call_with_whitespace_in_tags(self): self.assertEqual(result, 1.0) def test_call_multiple_thinking_blocks(self): - """Test __call__ with multiple thinking blocks (wrong format).""" + """Test __call__ with multiple thinking blocks (wrong format but correct language).""" response = """ First thought in English. Some text in between. Second thought also in English. """ result = self.reward_en("prompt", response) - # Multiple blocks = wrong format, should return 0 - self.assertEqual(result, 0.0) + # Multiple blocks = wrong format, but language is correct, should return partial_reward + self.assertEqual(result, 0.5) def test_call_multiple_thinking_blocks_mixed_languages(self): """Test __call__ with multiple thinking blocks in different languages (wrong format).""" @@ -162,8 +181,9 @@ def test_call_multiple_thinking_blocks_mixed_languages(self): これは短い日本語。 """ result = self.reward_en("prompt", response) - # Multiple blocks = wrong format, should return 0 - self.assertEqual(result, 0.0) + # Multiple blocks with mixed languages - langid will detect dominant language + # Should return either partial_reward (if detects English) or no_match_reward (if detects Japanese) + self.assertIn(result, [0.0, 0.5]) def test_call_multiline_thinking_block(self): """Test __call__ with multiline thinking blocks.""" @@ -192,20 +212,31 @@ def test_call_with_target_parameter(self): result = self.reward_en("prompt", response, target="some target") self.assertEqual(result, 1.0) - result = self.reward_en("prompt", "no tags", target="some target") - self.assertEqual(result, 0.0) + # Longer English text without tags should get fallback reward + result = self.reward_en( + "prompt", + "This is a response without thinking tags but in English language.", + target="some target", + ) + self.assertEqual(result, 0.2) def test_call_custom_reward_values(self): """Test __call__ with custom reward values.""" - response_ja = "これは日本語です。" + response_ja_single = "これは日本語です。" + response_ja_multiple = "最初の考え。次の考え。" + response_ja_no_tags = "これはタグなしの日本語です。" response_en = "This is English." - response_none = "no thinking tags" - - # Test custom full reward - self.assertEqual(self.custom_reward("prompt", response_ja), 0.9) - # Test custom no_match reward + response_none = "" + + # Test custom full reward (single block, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_single), 0.9) + # Test custom partial reward (multiple blocks, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.6) + # Test custom fallback reward (no blocks, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.3) + # Test custom no_match reward (wrong language) self.assertEqual(self.custom_reward("prompt", response_en), 0.1) - # Test no tags + # Test empty response self.assertEqual(self.custom_reward("prompt", response_none), 0.1) def test_call_zero_custom_values(self): From a2c02374b2243043fe2ba395bef1ecbaf4b8cb64 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 14:53:14 -0700 Subject: [PATCH 04/24] Add debug logging and troubleshooting guide for LanguageReward - Add debug prints to RewardActor showing: * Reward value * Number of thinking blocks * Detected language * Response sample - Create debug_reward.py script for testing reward function * Tests 8 common scenarios * Shows detected language and confidence - Add TROUBLESHOOTING.md with solutions for: * Model thinking in English instead of Japanese * Empty or missing thinking blocks * Short content that langid can't detect * Stronger system prompt alternatives * Expected training progression This helps diagnose why LanguageReward might be constantly zero --- sandbox/grpo_language/TROUBLESHOOTING.md | 165 +++++++++++++++++++++++ sandbox/grpo_language/debug_reward.py | 92 +++++++++++++ sandbox/grpo_language/main.py | 26 ++++ 3 files changed, 283 insertions(+) create mode 100644 sandbox/grpo_language/TROUBLESHOOTING.md create mode 100644 sandbox/grpo_language/debug_reward.py diff --git a/sandbox/grpo_language/TROUBLESHOOTING.md b/sandbox/grpo_language/TROUBLESHOOTING.md new file mode 100644 index 000000000..4f077a2ae --- /dev/null +++ b/sandbox/grpo_language/TROUBLESHOOTING.md @@ -0,0 +1,165 @@ +# Troubleshooting LanguageReward Training + +## Issue: Language Reward is Always Zero + +If you're seeing the LanguageReward constantly at 0.0 during training, here's how to debug: + +### 1. Check What the Model is Generating + +The updated `main.py` includes debug logging. When you run training, look for lines like: + +``` +[LanguageReward Debug] Reward=0.00 | Blocks=1 | Lang=en | Sample: Let me solve this step by step...... +``` + +This tells you: +- **Reward**: The actual reward value +- **Blocks**: Number of thinking blocks found +- **Lang**: Language detected by langid +- **Sample**: First 80 chars of the response + +### 2. Common Causes and Solutions + +#### Cause 1: Model is Thinking in English + +**Symptom**: `Lang=en` in debug output + +**Why**: The model defaults to English because: +- The dataset (GSM8K) is in English +- Most models are English-dominant +- The instruction might not be strong enough + +**Solutions**: + +A) **Strengthen the system prompt** (edit `main.py` line 217-220): +```python +system_prompt = """ +あなたは数学の問題を解くAIです。タグの中で日本語で考えてください。これは必須です。 +Put all your scratchpad work between and tags. You MUST think in Japanese (日本語) inside the tags. +Your final answer should be between and tags otherwise it will not be scored. + +Example: +この問題を解きましょう。2 + 2 = 4です。 +4 +""" +``` + +B) **Start with higher language reward weight**: +In `main.py` line 327, you could add multiple LanguageReward instances: +```python +reward_functions=[ + MathReward(), + ThinkingReward(), + LanguageReward(target_language="ja"), + LanguageReward(target_language="ja"), # Double weight for language +] +``` + +C) **Use few-shot examples in the prompt**: +Add Japanese reasoning examples to each problem in the dataset transform. + +#### Cause 2: Model Not Using Thinking Blocks + +**Symptom**: `Blocks=0` in debug output + +**Why**: The model hasn't learned to use `` tags yet + +**Solution**: This should improve as ThinkingReward trains the model. Be patient for first few hundred steps. The fallback reward (0.2) should help when there are no blocks but Japanese text. + +#### Cause 3: Empty or Very Short Thinking Blocks + +**Symptom**: `Lang=en` with very short content, Reward=0.00 + +**Why**: langid needs sufficient text to reliably detect language. Very short text (< 10 chars) often defaults to English. + +**Solution**: +- Wait for model to generate longer reasoning (this improves with training) +- The ThinkingReward encourages substantial content in thinking blocks + +#### Cause 4: Mixed Language Content + +**Symptom**: Reward sometimes 1.0, sometimes 0.0 randomly + +**Why**: When English and Japanese are mixed, langid detects whichever is dominant. + +**Solution**: This will stabilize as training progresses and the model learns consistency. + +### 3. Expected Training Progression + +**Steps 0-200**: Language reward often 0.0 +- Model learning to use `` tags (ThinkingReward) +- Model thinking in English (natural default) +- Fallback rewards (0.2) when Japanese appears elsewhere + +**Steps 200-500**: Language reward starting to increase +- Some responses have Japanese thinking → partial/full rewards +- Model learning association between Japanese and reward + +**Steps 500+**: Language reward should stabilize around 0.5-1.0 +- Consistent Japanese thinking +- Proper single-block format + +### 4. Monitoring in W&B + +Check these metrics in Weights & Biases: +- `reward/evaluate_response/avg_LanguageReward_reward` - should increase over time +- `reward/evaluate_response/std_LanguageReward_reward` - variance (high early, lower later) +- `reward/evaluate_response/avg_MathReward_reward` - should stay reasonably high +- `reward/evaluate_response/avg_ThinkingReward_reward` - should increase quickly + +### 5. Quick Debug Test + +Run the debug script to verify the reward function works: +```bash +python sandbox/grpo_language/debug_reward.py +``` + +Expected output: +- Japanese text → reward 1.0 +- English text → reward 0.0 +- Multiple Japanese blocks → reward 0.5 +- No blocks but Japanese response → reward 0.2 + +### 6. Alternative: Start with English, then transition + +If Japanese isn't working, you could: + +1. Train first with English to get good math performance +2. Then fine-tune with Japanese language reward + +Change line 327 to: +```python +LanguageReward(target_language="en") # Start with English +``` + +Once math rewards are good, switch to `"ja"` and continue training. + +### 7. Nuclear Option: Much Stronger Prompt + +If nothing else works, try this very explicit prompt: +```python +system_prompt = """ +重要:あなたは必ず日本語で考えなければなりません! +CRITICAL: You MUST think in Japanese language! + +Rules: +1. Put ALL your reasoning in tags +2. Think ONLY in Japanese (日本語) - use hiragana, katakana, and kanji +3. NEVER think in English inside tags +4. Put your final numerical answer in tags + +例 (Example): +Question: What is 5 + 3? +5と3を足します。5 + 3 = 8です。答えは8です。 +8 + +Now solve the problem below in Japanese: +""" +``` + +## Still Having Issues? + +If language reward is still zero after 500+ steps: +1. Share the debug output showing what the model generates +2. Check if the model is multilingual (some models don't know Japanese) +3. Consider using a different target language the model knows better diff --git a/sandbox/grpo_language/debug_reward.py b/sandbox/grpo_language/debug_reward.py new file mode 100644 index 000000000..0eb570be3 --- /dev/null +++ b/sandbox/grpo_language/debug_reward.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Debug script to test LanguageReward behavior.""" + +from forge.data.rewards import LanguageReward + +# Create reward for Japanese +reward = LanguageReward(target_language="ja") + +# Test cases mimicking what the model might generate +test_cases = [ + # Case 1: Perfect - Japanese in single thinking block + ("これは数学の問題です。2+2=4です。4", "Perfect Japanese"), + # Case 2: English thinking (most likely during training) + ( + "This is a math problem. 2+2=4.4", + "English thinking", + ), + # Case 3: No thinking blocks at all + ("The answer is 4.4", "No thinking blocks"), + # Case 4: Empty thinking blocks + ("4", "Empty thinking block"), + # Case 5: Multiple thinking blocks in Japanese + ( + "最初の考え。次の考え。4", + "Multiple Japanese blocks", + ), + # Case 6: Just the answer, no thinking + ("4", "Just answer tag"), + # Case 7: Thinking with mostly numbers/symbols + ("2 + 2 = 44", "Mostly numbers"), + # Case 8: Mixed English and Japanese + ("Let me think... これは簡単です。4", "Mixed languages"), +] + +print("=" * 80) +print("LanguageReward Debug Output (target_language='ja')") +print("=" * 80) + +for response, description in test_cases: + score = reward(prompt="", response=response, target=None) + + import re + + # Try to detect what langid thinks + import langid + + # Extract thinking content if exists + think_match = re.findall( + r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", response, re.IGNORECASE | re.DOTALL + ) + + if think_match: + content = " ".join(think_match) + detected_lang, confidence = langid.classify(content) + print(f"\n{description}:") + print(f" Response: {response[:60]}...") + print(f" Reward: {score}") + print(f" Thinking blocks found: {len(think_match)}") + print(f" Detected language: {detected_lang} (confidence: {confidence:.3f})") + else: + # Check fallback + response_text = re.sub( + r"<\s*/?\s*think\s*>", "", response, flags=re.IGNORECASE + ).strip() + if response_text: + detected_lang, confidence = langid.classify(response_text) + print(f"\n{description}:") + print(f" Response: {response[:60]}...") + print(f" Reward: {score}") + print(" Thinking blocks found: 0") + print( + f" Fallback detection on response text: {detected_lang} (confidence: {confidence:.3f})" + ) + else: + print(f"\n{description}:") + print(f" Response: {response[:60]}...") + print(f" Reward: {score}") + print(" No content to analyze") + +print("\n" + "=" * 80) +print("Expected rewards:") +print(" full_reward (1.0): Single Japanese thinking block") +print(" partial_reward (0.5): Multiple Japanese thinking blocks") +print(" fallback_reward (0.2): No blocks but Japanese response text") +print(" no_match_reward (0.0): Wrong language") +print("=" * 80) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 75219c39c..0b049873a 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -154,6 +154,32 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl reward_fn_name = getattr( reward_fn, "__name__", reward_fn.__class__.__name__ ) + + # Debug logging for LanguageReward to see what's happening + if reward_fn_name == "LanguageReward": + import re + + import langid + + think_matches = re.findall( + r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", + response, + re.IGNORECASE | re.DOTALL, + ) + if think_matches: + content = " ".join(think_matches) + detected_lang, confidence = langid.classify(content) + print( + f"[LanguageReward Debug] Reward={reward:.2f} | " + f"Blocks={len(think_matches)} | Lang={detected_lang} | " + f"Sample: {response[:80]}..." + ) + else: + print( + f"[LanguageReward Debug] Reward={reward:.2f} | " + f"Blocks=0 | Sample: {response[:80]}..." + ) + # per function reward record_metric( f"reward/evaluate_response/sum_{reward_fn_name}_reward", From afca75cd0c2794a1efee9f5b10a43a852af4697a Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 15:01:22 -0700 Subject: [PATCH 05/24] Add debug printing to LanguageReward and strengthen system prompt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core Changes: - Add debug mode to LanguageReward class * debug parameter: enables debug printing * debug_sample_rate: controls sampling (default 0.1 = 10%) * Prints detected language, confidence, reward, and sample text * Shows why reward was given (e.g., 'single block, correct language ✓') - Strengthen system prompt for Japanese thinking * Add Japanese instructions at top (あなたは数学の問題を解く...) * List explicit CRITICAL RULES * Include concrete example in Japanese * Much more forceful about requiring Japanese in tags - Enable debug in sandbox app * Set debug=True, debug_sample_rate=0.1 for LanguageReward * Will print every ~10th response with language detection details Example debug output: [LanguageReward] Found 1 thinking block(s) Target: ja | Detected: en | Confidence: -54.41 Thinking sample: Let me solve this step by step... → Reward: 0.0 (wrong language) ✗ This should help diagnose why rewards are zero during training --- sandbox/grpo_language/main.py | 20 ++++++++++-- src/forge/data/rewards.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 0b049873a..a6fb67c99 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -241,8 +241,20 @@ def setup(self): def gsm8k_transform(sample): system_prompt = """ - Put all your scratchpad work between and tags. You must think in Japanese inside the tags. - Your final answer should be between and tags otherwise it will not be scored. +あなたは数学の問題を解くAIアシスタントです。以下の重要なルールに従ってください: + +CRITICAL RULES: +1. Put ALL your reasoning inside and tags +2. You MUST think in Japanese (日本語) inside the tags - use hiragana, katakana, and kanji +3. NEVER use English inside tags +4. Put your final numerical answer inside and tags + +Example: +Question: What is 12 + 5? +12と5を足します。12 + 5 = 17です。したがって、答えは17です。 +17 + +Now solve the following problem using Japanese in your tags: """ request: str = sample["question"] as_chat = [ @@ -347,7 +359,9 @@ async def main(cfg: DictConfig): reward_functions=[ MathReward(), ThinkingReward(), - LanguageReward(target_language="ja"), # Japanese language reward + LanguageReward( + target_language="ja", debug=True, debug_sample_rate=0.1 + ), # Japanese language reward with debug ] ), ) diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 64a6c856e..834800e3d 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -94,6 +94,8 @@ class LanguageReward: partial_reward: Reward when language matches but format is wrong (multiple blocks) fallback_reward: Reward when no valid blocks but response text is in target language no_match_reward: Reward when language doesn't match + debug: If True, print debug samples showing model outputs and detected language + debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls) Note: Requires langid to be installed. Install with: pip install langid """ @@ -105,12 +107,17 @@ def __init__( partial_reward: float = 0.5, fallback_reward: float = 0.2, no_match_reward: float = 0.0, + debug: bool = False, + debug_sample_rate: float = 0.1, ): self.target_language = target_language self.full_reward = full_reward self.partial_reward = partial_reward self.fallback_reward = fallback_reward self.no_match_reward = no_match_reward + self.debug = debug + self.debug_sample_rate = debug_sample_rate + self._debug_counter = 0 self._THINK_BLOCK_RE = re.compile( r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL ) @@ -140,6 +147,14 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo fallback_reward if no valid blocks but response text is in target language, no_match_reward otherwise (wrong language) """ + # Increment counter for sampling + self._debug_counter += 1 + should_debug = ( + self.debug + and self.debug_sample_rate > 0 + and (self._debug_counter % int(1 / self.debug_sample_rate)) == 0 + ) + if not response: return self.no_match_reward @@ -154,15 +169,34 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo ).strip() if not response_text: + if should_debug: + print( + f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}" + ) return self.no_match_reward # Detect language of general response detected_lang, confidence = self._langid.classify(response_text) + if should_debug: + sample = response[:150].replace("\n", " ") + print( + f"\n[LanguageReward] No thinking blocks found (FALLBACK mode)" + f"\n Target: {self.target_language} | Detected: {detected_lang} | " + f"Confidence: {confidence:.2f}" + f"\n Sample: {sample}..." + ) + # Give fallback reward if response is in target language if detected_lang == self.target_language: + if should_debug: + print( + f" → Reward: {self.fallback_reward} (fallback, correct language)" + ) return self.fallback_reward + if should_debug: + print(f" → Reward: {self.no_match_reward} (wrong language)") return self.no_match_reward # Concatenate all thinking blocks for language detection @@ -172,18 +206,41 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo thinking_content = re.sub(r"\s+", " ", thinking_content).strip() if not thinking_content: + if should_debug: + print( + f"\n[LanguageReward] Empty thinking blocks | Reward: {self.no_match_reward}" + ) return self.no_match_reward # Detect language using langid detected_lang, confidence = self._langid.classify(thinking_content) + if should_debug: + sample = thinking_content[:150].replace("\n", " ") + print( + f"\n[LanguageReward] Found {len(matches)} thinking block(s)" + f"\n Target: {self.target_language} | Detected: {detected_lang} | " + f"Confidence: {confidence:.2f}" + f"\n Thinking sample: {sample}..." + ) + # Check if language matches target if detected_lang == self.target_language: # Full reward for correct format (single block) if len(matches) == 1: + if should_debug: + print( + f" → Reward: {self.full_reward} (single block, correct language) ✓" + ) return self.full_reward # Partial reward for wrong format (multiple blocks) but correct language else: + if should_debug: + print( + f" → Reward: {self.partial_reward} (multiple blocks, correct language)" + ) return self.partial_reward + if should_debug: + print(f" → Reward: {self.no_match_reward} (wrong language) ✗") return self.no_match_reward From 2625b282d392c0482dc39b66614d5d0cd9701986 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 15:19:31 -0700 Subject: [PATCH 06/24] =?UTF-8?q?Refactor=20to=20use=20configurable=20Japa?= =?UTF-8?q?nese=20tags=20<=E6=80=9D=E8=80=83>=20instead=20of=20English=20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BREAKING CHANGE: Default tag for LanguageReward changed from 'think' to '思考' Key changes: - Both ThinkingReward and LanguageReward now accept 'tag' parameter - ThinkingReward default: 'think' (backward compatible) - LanguageReward default: '思考' (Japanese, breaks English associations) - Sandbox app uses <思考> tags throughout - System prompt updated to Japanese with <思考> examples - All tests updated and passing (29/29) - Debug script updated Rationale: Models may be heavily trained to think in English when using tags. Using Japanese tags <思考> (shikō = 'thinking') breaks this association and encourages thinking in the target language. --- sandbox/grpo_language/README.md | 28 ++++++--- sandbox/grpo_language/debug_reward.py | 20 +++--- sandbox/grpo_language/main.py | 20 +++--- src/forge/data/rewards.py | 34 ++++++++--- tests/unit_tests/rl/test_language_reward.py | 67 +++++++++++---------- 5 files changed, 97 insertions(+), 72 deletions(-) diff --git a/sandbox/grpo_language/README.md b/sandbox/grpo_language/README.md index 46c573cb8..ccab37e04 100644 --- a/sandbox/grpo_language/README.md +++ b/sandbox/grpo_language/README.md @@ -4,19 +4,25 @@ This sandbox app demonstrates using GRPO training with a language reward that en ## Overview -This app extends the standard GRPO training (from `apps/grpo/`) by adding a `LanguageReward` that evaluates whether the model's thinking (text within `` tags) is in the target language. +This app extends the standard GRPO training (from `apps/grpo/`) by adding a `LanguageReward` that evaluates whether the model's thinking (text within `<思考>` tags) is in the target language. + +**Key Insight**: Uses Japanese tags `<思考>` (shikō = "thinking") instead of English `` tags to break the model's association between thinking tags and English language. This helps encourage multilingual thinking. ## Key Features - **Multi-objective training**: Combines three rewards: - `MathReward`: Evaluates correctness of math answers - - `ThinkingReward`: Encourages use of thinking tags + - `ThinkingReward`: Encourages use of `<思考>` tags - `LanguageReward`: Rewards thinking in target language (Japanese by default) +- **Japanese thinking tags**: Uses `<思考>` instead of `` to encourage non-English reasoning + - **Language detection**: Uses `langid` to detect the language of thinking blocks - **Configurable target language**: While this app defaults to Japanese (`ja`), the `LanguageReward` can be configured for any ISO 639-1 language code +- **Configurable tags**: Both rewards support custom tag names via the `tag` parameter + ## Requirements Before running this app, install the required language detection library: @@ -33,25 +39,29 @@ python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.y ## How It Works -1. The model receives a math problem and is instructed to use `` tags for reasoning +1. The model receives a math problem and is instructed to use `<思考>` tags for reasoning 2. During training, the model generates responses with thinking blocks 3. Three rewards are computed: - Math correctness (did it get the right answer?) - - Thinking usage (did it use thinking tags properly?) + - Thinking usage (did it use `<思考>` tags properly?) - Language usage (did it think in Japanese?) 4. The model is trained to maximize all three rewards ## Configuration -The target language is hardcoded as Japanese in `main.py` (line 321): +### Target Language + +The target language is configured as Japanese in `main.py`: ```python -LanguageReward(target_language="ja") +LanguageReward(target_language="ja", tag="思考") +ThinkingReward(tag="思考") ``` -To use a different language, modify this line with the appropriate ISO 639-1 code: -- English: `"en"` -- Chinese: `"zh"` +To use a different language: +1. Change `target_language` to the appropriate ISO 639-1 code: + - English: `"en"` + - Chinese: `"zh"` - Spanish: `"es"` - French: `"fr"` - etc. diff --git a/sandbox/grpo_language/debug_reward.py b/sandbox/grpo_language/debug_reward.py index 0eb570be3..b79afe2a6 100644 --- a/sandbox/grpo_language/debug_reward.py +++ b/sandbox/grpo_language/debug_reward.py @@ -15,27 +15,27 @@ # Test cases mimicking what the model might generate test_cases = [ # Case 1: Perfect - Japanese in single thinking block - ("これは数学の問題です。2+2=4です。4", "Perfect Japanese"), + ("<思考>これは数学の問題です。2+2=4です。4", "Perfect Japanese"), # Case 2: English thinking (most likely during training) ( - "This is a math problem. 2+2=4.4", + "<思考>This is a math problem. 2+2=4.4", "English thinking", ), # Case 3: No thinking blocks at all ("The answer is 4.4", "No thinking blocks"), # Case 4: Empty thinking blocks - ("4", "Empty thinking block"), + ("<思考>4", "Empty thinking block"), # Case 5: Multiple thinking blocks in Japanese ( - "最初の考え。次の考え。4", + "<思考>最初の考え。<思考>次の考え。4", "Multiple Japanese blocks", ), # Case 6: Just the answer, no thinking ("4", "Just answer tag"), # Case 7: Thinking with mostly numbers/symbols - ("2 + 2 = 44", "Mostly numbers"), + ("<思考>2 + 2 = 44", "Mostly numbers"), # Case 8: Mixed English and Japanese - ("Let me think... これは簡単です。4", "Mixed languages"), + ("<思考>Let me think... これは簡単です。4", "Mixed languages"), ] print("=" * 80) @@ -51,9 +51,7 @@ import langid # Extract thinking content if exists - think_match = re.findall( - r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", response, re.IGNORECASE | re.DOTALL - ) + think_match = re.findall(r"<\s*思考\s*>(.*?)<\s*/\s*思考\s*>", response, re.DOTALL) if think_match: content = " ".join(think_match) @@ -65,9 +63,7 @@ print(f" Detected language: {detected_lang} (confidence: {confidence:.3f})") else: # Check fallback - response_text = re.sub( - r"<\s*/?\s*think\s*>", "", response, flags=re.IGNORECASE - ).strip() + response_text = re.sub(r"<\s*/?\s*思考\s*>", "", response).strip() if response_text: detected_lang, confidence = langid.classify(response_text) print(f"\n{description}:") diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index a6fb67c99..bd4633722 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -243,18 +243,18 @@ def gsm8k_transform(sample): system_prompt = """ あなたは数学の問題を解くAIアシスタントです。以下の重要なルールに従ってください: -CRITICAL RULES: -1. Put ALL your reasoning inside and tags -2. You MUST think in Japanese (日本語) inside the tags - use hiragana, katakana, and kanji -3. NEVER use English inside tags -4. Put your final numerical answer inside and tags +重要なルール (CRITICAL RULES): +1. すべての思考過程を <思考> と タグの中に入れてください +2. <思考> タグの中では必ず日本語で考えてください(ひらがな、カタカナ、漢字を使用) +3. <思考> タグの中では絶対に英語を使わないでください +4. 最終的な数値の答えを タグの中に入れてください -Example: +例 (Example): Question: What is 12 + 5? -12と5を足します。12 + 5 = 17です。したがって、答えは17です。 +<思考>12と5を足します。12 + 5 = 17です。したがって、答えは17です。 17 -Now solve the following problem using Japanese in your tags: +以下の問題を <思考> タグの中で日本語を使って解いてください: """ request: str = sample["question"] as_chat = [ @@ -358,9 +358,9 @@ async def main(cfg: DictConfig): RewardActor.options(**cfg.services.reward_actor).as_service( reward_functions=[ MathReward(), - ThinkingReward(), + ThinkingReward(tag="思考"), # Use Japanese tag LanguageReward( - target_language="ja", debug=True, debug_sample_rate=0.1 + target_language="ja", tag="思考", debug=True, debug_sample_rate=0.1 ), # Japanese language reward with debug ] ), diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 834800e3d..a966d5c86 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -57,15 +57,28 @@ def _to_float(self, text: str) -> float | None: class ThinkingReward: - """Reward class for evaluating use of tags in reasoning.""" + """Reward class for evaluating use of thinking tags in reasoning. - def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0): + Args: + partial_reward: Reward for partial tag usage (incomplete/malformed) + full_reward: Reward for well-formed thinking blocks with content + tag: Tag name to use (default "think", can use "思考" for Japanese, etc.) + """ + + def __init__( + self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think" + ): self.partial_reward = partial_reward self.full_reward = full_reward + self.tag = tag + # Build regex patterns for the specified tag self._THINK_BLOCK_RE = re.compile( - r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL + rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", + re.IGNORECASE | re.DOTALL, + ) + self._THINK_TAG_ATTEMPT_RE = re.compile( + rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE ) - self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE) def __call__(self, prompt: str, response: str, target: str | None = None) -> float: """Compute thinking reward.""" @@ -83,7 +96,7 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo class LanguageReward: - """Reward class for evaluating the language used in tags. + """Reward class for evaluating the language used in thinking tags. This reward uses langid to detect the language of text within thinking blocks and rewards responses that use the target language. @@ -94,6 +107,7 @@ class LanguageReward: partial_reward: Reward when language matches but format is wrong (multiple blocks) fallback_reward: Reward when no valid blocks but response text is in target language no_match_reward: Reward when language doesn't match + tag: Tag name to use (default "思考" for multilingual, can use "think", etc.) debug: If True, print debug samples showing model outputs and detected language debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls) @@ -107,6 +121,7 @@ def __init__( partial_reward: float = 0.5, fallback_reward: float = 0.2, no_match_reward: float = 0.0, + tag: str = "思考", debug: bool = False, debug_sample_rate: float = 0.1, ): @@ -115,12 +130,15 @@ def __init__( self.partial_reward = partial_reward self.fallback_reward = fallback_reward self.no_match_reward = no_match_reward + self.tag = tag self.debug = debug self.debug_sample_rate = debug_sample_rate self._debug_counter = 0 + # Build regex pattern for the specified tag self._THINK_BLOCK_RE = re.compile( - r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL + rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL ) + self._TAG_PATTERN = rf"<\s*/?\s*{re.escape(tag)}\s*>" # Lazy import langid with helpful error message try: @@ -164,9 +182,7 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo # If no thinking blocks found, check if response text is in target language if len(matches) == 0: # Remove any partial tags that might exist - response_text = re.sub( - r"<\s*/?\s*think\s*>", "", response, flags=re.IGNORECASE - ).strip() + response_text = re.sub(self._TAG_PATTERN, "", response).strip() if not response_text: if should_debug: diff --git a/tests/unit_tests/rl/test_language_reward.py b/tests/unit_tests/rl/test_language_reward.py index b88846fcf..1df1eb460 100644 --- a/tests/unit_tests/rl/test_language_reward.py +++ b/tests/unit_tests/rl/test_language_reward.py @@ -81,13 +81,13 @@ def test_regex_pattern(self): def test_call_with_english_thinking(self): """Test __call__ with English text in thinking blocks.""" - response = "This is English reasoning about math problems." + response = "<思考>This is English reasoning about math problems." result = self.reward_en("prompt", response) self.assertEqual(result, 1.0) def test_call_with_japanese_thinking(self): """Test __call__ with Japanese text in thinking blocks.""" - response = "これは日本語で考えています。数学の問題を解きます。" + response = "<思考>これは日本語で考えています。数学の問題を解きます。" result = self.reward_ja("prompt", response) self.assertEqual(result, 1.0) @@ -97,7 +97,7 @@ def test_call_with_japanese_thinking(self): def test_call_with_chinese_thinking(self): """Test __call__ with Chinese text in thinking blocks.""" - response = "这是中文思考。我们需要解决这个数学问题。" + response = "<思考>这是中文思考。我们需要解决这个数学问题。" reward_zh = self.LanguageReward(target_language="zh") result = reward_zh("prompt", response) # langid should detect this as Chinese (zh) @@ -105,7 +105,9 @@ def test_call_with_chinese_thinking(self): def test_call_with_spanish_thinking(self): """Test __call__ with Spanish text in thinking blocks.""" - response = "Este es un razonamiento en español sobre problemas matemáticos." + response = ( + "<思考>Este es un razonamiento en español sobre problemas matemáticos." + ) reward_es = self.LanguageReward(target_language="es") result = reward_es("prompt", response) # langid should detect this as Spanish (es) @@ -114,12 +116,12 @@ def test_call_with_spanish_thinking(self): def test_call_language_mismatch(self): """Test __call__ when detected language doesn't match target.""" # Japanese reward with English text - response = "This is English reasoning." + response = "<思考>This is English reasoning." result = self.reward_ja("prompt", response) self.assertEqual(result, 0.0) # English reward with Japanese text - response = "これは日本語です。" + response = "<思考>これは日本語です。" result = self.reward_en("prompt", response) self.assertEqual(result, 0.0) @@ -139,36 +141,37 @@ def test_call_with_no_thinking_tags_wrong_language(self): def test_call_with_empty_thinking_block(self): """Test __call__ with empty thinking block.""" - result = self.reward_en("prompt", "") + result = self.reward_en("prompt", "<思考>") self.assertEqual(result, 0.0) def test_call_with_whitespace_only_thinking_block(self): """Test __call__ with whitespace-only thinking block.""" - result = self.reward_en("prompt", " \n \t ") + result = self.reward_en("prompt", "<思考> \n \t ") self.assertEqual(result, 0.0) - def test_call_case_insensitive(self): - """Test __call__ is case insensitive for thinking tags.""" - response = "This is English reasoning." + def test_call_with_proper_tags(self): + """Test __call__ with properly formatted thinking tags.""" + response = "<思考>This is English reasoning." result = self.reward_en("prompt", response) self.assertEqual(result, 1.0) - response = "This is English reasoning." - result = self.reward_en("prompt", response) + # Japanese content should also work + response = "<思考>これは日本語です。" + result = self.reward_ja("prompt", response) self.assertEqual(result, 1.0) def test_call_with_whitespace_in_tags(self): """Test __call__ with whitespace in thinking tags.""" - response = "< think >This is English reasoning." + response = "< 思考 >This is English reasoning." result = self.reward_en("prompt", response) self.assertEqual(result, 1.0) def test_call_multiple_thinking_blocks(self): """Test __call__ with multiple thinking blocks (wrong format but correct language).""" response = """ - First thought in English. + <思考>First thought in English. Some text in between. - Second thought also in English. + <思考>Second thought also in English. """ result = self.reward_en("prompt", response) # Multiple blocks = wrong format, but language is correct, should return partial_reward @@ -177,8 +180,8 @@ def test_call_multiple_thinking_blocks(self): def test_call_multiple_thinking_blocks_mixed_languages(self): """Test __call__ with multiple thinking blocks in different languages (wrong format).""" response = """ - First thought in English with lots of content here. - これは短い日本語。 + <思考>First thought in English with lots of content here. + <思考>これは短い日本語。 """ result = self.reward_en("prompt", response) # Multiple blocks with mixed languages - langid will detect dominant language @@ -187,12 +190,12 @@ def test_call_multiple_thinking_blocks_mixed_languages(self): def test_call_multiline_thinking_block(self): """Test __call__ with multiline thinking blocks.""" - response = """ + response = """<思考> This is a multiline thinking block with lots of English content about solving problems - """ + """ result = self.reward_en("prompt", response) self.assertEqual(result, 1.0) @@ -208,7 +211,7 @@ def test_call_none_response(self): def test_call_with_target_parameter(self): """Test __call__ with target parameter (should be ignored).""" - response = "This is English reasoning." + response = "<思考>This is English reasoning." result = self.reward_en("prompt", response, target="some target") self.assertEqual(result, 1.0) @@ -222,10 +225,10 @@ def test_call_with_target_parameter(self): def test_call_custom_reward_values(self): """Test __call__ with custom reward values.""" - response_ja_single = "これは日本語です。" - response_ja_multiple = "最初の考え。次の考え。" + response_ja_single = "<思考>これは日本語です。" + response_ja_multiple = "<思考>最初の考え。<思考>次の考え。" response_ja_no_tags = "これはタグなしの日本語です。" - response_en = "This is English." + response_en = "<思考>This is English." response_none = "" # Test custom full reward (single block, correct language) @@ -244,13 +247,13 @@ def test_call_zero_custom_values(self): zero_reward = self.LanguageReward( target_language="en", full_reward=0.0, no_match_reward=0.0 ) - result = zero_reward("prompt", "This is English.") + result = zero_reward("prompt", "<思考>This is English.") self.assertEqual(result, 0.0) def test_call_with_special_characters(self): """Test __call__ with special characters in thinking blocks.""" response = ( - "English with special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~" + "<思考>English with special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~" ) result = self.reward_en("prompt", response) self.assertEqual(result, 1.0) @@ -260,7 +263,7 @@ def test_call_with_mixed_content_outside_tags(self): # Content outside think tags should be ignored response = """ これは日本語のテキストです。 - But this is English reasoning inside the tags. + <思考>But this is English reasoning inside the tags. もっと日本語のテキスト。 """ result = self.reward_en("prompt", response) @@ -269,7 +272,7 @@ def test_call_with_mixed_content_outside_tags(self): def test_call_with_numbers_and_symbols(self): """Test __call__ with thinking blocks containing mostly numbers.""" - response = "Calculate: 2 + 2 = 4, then 4 * 3 = 12" + response = "<思考>Calculate: 2 + 2 = 4, then 4 * 3 = 12" result = self.reward_en("prompt", response) # Should still detect as English due to words like "Calculate" and "then" self.assertEqual(result, 1.0) @@ -277,17 +280,17 @@ def test_call_with_numbers_and_symbols(self): def test_call_very_long_thinking_block(self): """Test __call__ with very long thinking blocks.""" long_content = "This is English content. " * 1000 - result = self.reward_en("prompt", f"{long_content}") + result = self.reward_en("prompt", f"<思考>{long_content}") self.assertEqual(result, 1.0) def test_call_with_code_in_thinking(self): """Test __call__ with code snippets in thinking blocks.""" - response = """ + response = """<思考> Let me write some Python code to solve this: def calculate(x): return x * 2 The function doubles the input value. - """ + """ result = self.reward_en("prompt", response) # Should detect as English due to surrounding text self.assertEqual(result, 1.0) @@ -304,7 +307,7 @@ def test_different_language_codes(self): for lang_code, text in languages.items(): reward = self.LanguageReward(target_language=lang_code) - response = f"{text}" + response = f"<思考>{text}" result = reward("prompt", response) # langid should detect these correctly self.assertEqual( From 1a4d5fb0c266c1f84b4be005110aaf80c65b3f8e Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 15:37:56 -0700 Subject: [PATCH 07/24] Remove old debug code from main.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old debug code was checking for tags instead of <思考> tags, which was misleading. The LanguageReward class now has its own debug mode that correctly checks for the configured tag. --- sandbox/grpo_language/main.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index bd4633722..9014e3b32 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -155,31 +155,6 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl reward_fn, "__name__", reward_fn.__class__.__name__ ) - # Debug logging for LanguageReward to see what's happening - if reward_fn_name == "LanguageReward": - import re - - import langid - - think_matches = re.findall( - r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", - response, - re.IGNORECASE | re.DOTALL, - ) - if think_matches: - content = " ".join(think_matches) - detected_lang, confidence = langid.classify(content) - print( - f"[LanguageReward Debug] Reward={reward:.2f} | " - f"Blocks={len(think_matches)} | Lang={detected_lang} | " - f"Sample: {response[:80]}..." - ) - else: - print( - f"[LanguageReward Debug] Reward={reward:.2f} | " - f"Blocks=0 | Sample: {response[:80]}..." - ) - # per function reward record_metric( f"reward/evaluate_response/sum_{reward_fn_name}_reward", From 4e87a4d6e69e7ebf5a5053802a99fd55e6e97819 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 16:11:38 -0700 Subject: [PATCH 08/24] Weaken system prompt to rely more on RL rewards MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed from strong Japanese instructions with CRITICAL RULES to a simple English prompt that just mentions the <思考> tags with a Japanese example. This allows the RL training to do the work of encouraging Japanese thinking through rewards rather than relying on heavy prompt engineering. --- sandbox/grpo_language/main.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 9014e3b32..bd8afc304 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -216,20 +216,14 @@ def setup(self): def gsm8k_transform(sample): system_prompt = """ -あなたは数学の問題を解くAIアシスタントです。以下の重要なルールに従ってください: +You are a helpful AI assistant that solves math problems. -重要なルール (CRITICAL RULES): -1. すべての思考過程を <思考> と タグの中に入れてください -2. <思考> タグの中では必ず日本語で考えてください(ひらがな、カタカナ、漢字を使用) -3. <思考> タグの中では絶対に英語を使わないでください -4. 最終的な数値の答えを タグの中に入れてください +Please show your reasoning inside <思考> tags, then provide your final numerical answer inside tags. -例 (Example): +Example: Question: What is 12 + 5? -<思考>12と5を足します。12 + 5 = 17です。したがって、答えは17です。 +<思考>12と5を足します。12 + 5 = 17です。 17 - -以下の問題を <思考> タグの中で日本語を使って解いてください: """ request: str = sample["question"] as_chat = [ From abb653ee647e04e7756a760c80f13f84c3c204c1 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 16:40:52 -0700 Subject: [PATCH 09/24] Remove sandbox config and reference apps/grpo configs instead MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: - Removed sandbox/grpo_language/qwen3_1_7b.yaml (use configs from apps/grpo/) - Updated usage comment in main.py to reference apps/grpo/qwen3_1_7b.yaml - Updated README.md to reference apps/grpo/ configs - Fixed README.md to use <思考> tags instead of tags --- sandbox/grpo_language/README.md | 6 +- sandbox/grpo_language/main.py | 2 +- sandbox/grpo_language/qwen3_1_7b.yaml | 151 -------------------------- 3 files changed, 5 insertions(+), 154 deletions(-) delete mode 100644 sandbox/grpo_language/qwen3_1_7b.yaml diff --git a/sandbox/grpo_language/README.md b/sandbox/grpo_language/README.md index ccab37e04..cfa97af40 100644 --- a/sandbox/grpo_language/README.md +++ b/sandbox/grpo_language/README.md @@ -34,9 +34,11 @@ pip install langid ## Usage ```bash -python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml +python -m sandbox.grpo_language.main --config apps/grpo/qwen3_1_7b.yaml ``` +You can use any of the config files from `apps/grpo/` (e.g., `qwen3_1_7b.yaml`, `qwen3_8b.yaml`, `qwen3_32b.yaml`). + ## How It Works 1. The model receives a math problem and is instructed to use `<思考>` tags for reasoning @@ -70,7 +72,7 @@ To use a different language: Over the course of training, the model should learn to: 1. Solve math problems correctly -2. Use `` tags for its reasoning +2. Use `<思考>` tags for its reasoning 3. Write its thinking in Japanese (or the configured target language) ## Metrics diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index bd8afc304..c27c55638 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Usage: python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml +# Usage: python -m sandbox.grpo_language.main --config apps/grpo/qwen3_1_7b.yaml import asyncio import time diff --git a/sandbox/grpo_language/qwen3_1_7b.yaml b/sandbox/grpo_language/qwen3_1_7b.yaml deleted file mode 100644 index e6715a219..000000000 --- a/sandbox/grpo_language/qwen3_1_7b.yaml +++ /dev/null @@ -1,151 +0,0 @@ -# Grouped Relative Policy Optimization (GRPO) with Language Reward -# >>> python -m sandbox.grpo_language.main --config sandbox/grpo_language/qwen3_1_7b.yaml - -# Global configuration -group_size: 8 -local_batch_size: 16 # per-device batch size -max_req_tokens: 1024 -max_res_tokens: 1024 -model: "Qwen/Qwen3-1.7B" -off_by_n: 1 # Off by one by default - -# Main loop configuration -rollout_threads: 1 # Recommended to set equal to policy.num_replicas - - -# Observability configuration -metric_logging: - wandb: - project: grpo-training - group: grpo_language_exp_${oc.env:USER} - logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce - console: - logging_mode: global_reduce - -# Dataset configuration -dataset: - path: "openai/gsm8k" - revision: "main" - data_split: "train" - streaming: true - model: ${model} - -# Policy configuration -policy: - engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs - model: ${model} - tensor_parallel_size: 1 - pipeline_parallel_size: 1 - enforce_eager: false - sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams - n: ${group_size} - max_tokens: ${max_res_tokens} - temperature: 1.0 - top_p: 1.0 - -# Trainer configuration -trainer: - model: - name: qwen3 - flavor: 1.7B - hf_assets_path: hf://${model} - optimizer: - name: AdamW - lr: 1e-5 - eps: 1e-8 - lr_scheduler: - warmup_steps: 1 - training: - local_batch_size: ${local_batch_size} - seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens - max_norm: 1.0 - steps: 1000000 - dtype: bfloat16 - gc_freq: 1 - compile: - enable: false - parallelism: - data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 1 - tensor_parallel_degree: 1 - pipeline_parallel_degree: 1 - context_parallel_degree: 1 - expert_parallel_degree: 1 - disable_loss_parallel: true - checkpoint: - enable: true - folder: ./checkpoint # The folder to save checkpoints to. - initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. - initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo - last_save_in_hf: true - interval: 500 - async_mode: "disabled" - activation_checkpoint: - mode: selective - selective_ac_option: op - -# Replay buffer configuration -replay_buffer: - batch_size: ${local_batch_size} - max_policy_age: ${off_by_n} - dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree - -# Reference model configuration -ref_model: - model: - name: qwen3 - flavor: 1.7B - hf_assets_path: hf://${model} - training: - seq_len: ${trainer.training.seq_len} - dtype: bfloat16 - gc_freq: 1 - compile: - enable: false - parallelism: - data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 1 - tensor_parallel_degree: 1 - pipeline_parallel_degree: 1 - context_parallel_degree: 1 - expert_parallel_degree: 1 - checkpoint: - enable: true - initial_load_path: hf://${model} - initial_load_in_hf: true - -# All resource allocations -services: - policy: - procs: ${policy.engine_args.tensor_parallel_size} - num_replicas: 1 - mesh_name: policy - with_gpus: true - ref_model: - procs: 1 - num_replicas: 1 - mesh_name: ref_model - with_gpus: true - reward_actor: - procs: 1 - num_replicas: 1 - mesh_name: reward_actor - with_gpus: false - -actors: - dataset: - procs: 1 - with_gpus: false - mesh_name: dataset - trainer: - procs: 1 - with_gpus: true - mesh_name: trainer - replay_buffer: - procs: 1 - with_gpus: false - mesh_name: replay_buffer - compute_advantages: - procs: 1 - with_gpus: false - mesh_name: compute_advantages From 7b4829c6edf255b504c66e67151e2a45407e0bd7 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 16:47:08 -0700 Subject: [PATCH 10/24] Simplify LanguageReward logic to focus on language detection only Since ThinkingReward already enforces format (single vs multiple blocks), LanguageReward now focuses purely on language detection with simplified logic: Detection strategy: - If exactly one thinking block: detect language of block content only - Otherwise (no blocks or multiple blocks): detect language of whole response - Returns match_reward (1.0) if language matches, no_match_reward (0.0) otherwise Changes: - Removed partial_reward and fallback_reward parameters (now just match/no-match) - Renamed full_reward to match_reward for clarity - Updated all 29 tests to match new behavior (all passing) - Updated README with clearer explanation of reward separation - Updated debug script with new expected rewards This separation of concerns allows each reward to specialize: - ThinkingReward: format enforcement - LanguageReward: language detection --- sandbox/grpo_language/README.md | 11 +- sandbox/grpo_language/debug_reward.py | 11 +- src/forge/data/rewards.py | 136 +++++++------------- tests/unit_tests/rl/test_language_reward.py | 52 ++++---- 4 files changed, 85 insertions(+), 125 deletions(-) diff --git a/sandbox/grpo_language/README.md b/sandbox/grpo_language/README.md index cfa97af40..f40b15c1a 100644 --- a/sandbox/grpo_language/README.md +++ b/sandbox/grpo_language/README.md @@ -44,11 +44,16 @@ You can use any of the config files from `apps/grpo/` (e.g., `qwen3_1_7b.yaml`, 1. The model receives a math problem and is instructed to use `<思考>` tags for reasoning 2. During training, the model generates responses with thinking blocks 3. Three rewards are computed: - - Math correctness (did it get the right answer?) - - Thinking usage (did it use `<思考>` tags properly?) - - Language usage (did it think in Japanese?) + - **MathReward**: Did it get the right answer? + - **ThinkingReward**: Did it use `<思考>` tags properly? (single block = full reward, multiple blocks = partial reward) + - **LanguageReward**: Did it use the target language? Detection strategy: + - If exactly one thinking block: detect language of block content only + - Otherwise (no blocks or multiple blocks): detect language of whole response + - Returns match_reward (1.0) if detected language matches target, no_match_reward (0.0) otherwise 4. The model is trained to maximize all three rewards +**Note**: ThinkingReward enforces format (single vs multiple blocks), while LanguageReward focuses purely on language detection. This separation of concerns allows each reward to specialize in one aspect of the desired behavior. + ## Configuration ### Target Language diff --git a/sandbox/grpo_language/debug_reward.py b/sandbox/grpo_language/debug_reward.py index b79afe2a6..66e44342b 100644 --- a/sandbox/grpo_language/debug_reward.py +++ b/sandbox/grpo_language/debug_reward.py @@ -80,9 +80,10 @@ print(" No content to analyze") print("\n" + "=" * 80) -print("Expected rewards:") -print(" full_reward (1.0): Single Japanese thinking block") -print(" partial_reward (0.5): Multiple Japanese thinking blocks") -print(" fallback_reward (0.2): No blocks but Japanese response text") -print(" no_match_reward (0.0): Wrong language") +print("Expected rewards (simplified logic):") +print(" match_reward (1.0): Detected language matches target (ja)") +print(" no_match_reward (0.0): Detected language doesn't match target") +print("\nDetection strategy:") +print(" - Single thinking block: detect language of block content only") +print(" - Multiple blocks or no blocks: detect language of whole response") print("=" * 80) diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index a966d5c86..91ed7fea5 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -96,17 +96,20 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo class LanguageReward: - """Reward class for evaluating the language used in thinking tags. + """Reward class for evaluating the language used in responses. - This reward uses langid to detect the language of text within thinking blocks - and rewards responses that use the target language. + This reward uses langid to detect the language and rewards responses that use + the target language. The detection strategy depends on the format: + - If exactly one thinking block: detect language of the block content + - Otherwise (no blocks or multiple blocks): detect language of whole response + + Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward. + This reward focuses purely on language detection. Args: target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es') - full_reward: Reward when language matches and format is correct (single block) - partial_reward: Reward when language matches but format is wrong (multiple blocks) - fallback_reward: Reward when no valid blocks but response text is in target language - no_match_reward: Reward when language doesn't match + match_reward: Reward when detected language matches target (default: 1.0) + no_match_reward: Reward when language doesn't match (default: 0.0) tag: Tag name to use (default "思考" for multilingual, can use "think", etc.) debug: If True, print debug samples showing model outputs and detected language debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls) @@ -117,18 +120,14 @@ class LanguageReward: def __init__( self, target_language: str = "en", - full_reward: float = 1.0, - partial_reward: float = 0.5, - fallback_reward: float = 0.2, + match_reward: float = 1.0, no_match_reward: float = 0.0, tag: str = "思考", debug: bool = False, debug_sample_rate: float = 0.1, ): self.target_language = target_language - self.full_reward = full_reward - self.partial_reward = partial_reward - self.fallback_reward = fallback_reward + self.match_reward = match_reward self.no_match_reward = no_match_reward self.tag = tag self.debug = debug @@ -138,7 +137,6 @@ def __init__( self._THINK_BLOCK_RE = re.compile( rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL ) - self._TAG_PATTERN = rf"<\s*/?\s*{re.escape(tag)}\s*>" # Lazy import langid with helpful error message try: @@ -152,18 +150,19 @@ def __init__( ) from None def __call__(self, prompt: str, response: str, target: str | None = None) -> float: - """Compute language reward based on thinking block content. + """Compute language reward based on detected language. + + Detection strategy: + - If exactly one thinking block: detect language of block content + - Otherwise: detect language of whole response Args: prompt: The input prompt (unused but kept for signature consistency) - response: The model response containing tags + response: The model response target: Optional target string (unused but kept for signature consistency) Returns: - full_reward if language matches and exactly one thinking block is found, - partial_reward if language matches but multiple thinking blocks found, - fallback_reward if no valid blocks but response text is in target language, - no_match_reward otherwise (wrong language) + match_reward if detected language matches target, no_match_reward otherwise """ # Increment counter for sampling self._debug_counter += 1 @@ -174,89 +173,52 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo ) if not response: - return self.no_match_reward - - # Extract all thinking blocks - matches = self._THINK_BLOCK_RE.findall(response) - - # If no thinking blocks found, check if response text is in target language - if len(matches) == 0: - # Remove any partial tags that might exist - response_text = re.sub(self._TAG_PATTERN, "", response).strip() - - if not response_text: - if should_debug: - print( - f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}" - ) - return self.no_match_reward - - # Detect language of general response - detected_lang, confidence = self._langid.classify(response_text) - if should_debug: - sample = response[:150].replace("\n", " ") print( - f"\n[LanguageReward] No thinking blocks found (FALLBACK mode)" - f"\n Target: {self.target_language} | Detected: {detected_lang} | " - f"Confidence: {confidence:.2f}" - f"\n Sample: {sample}..." + f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}" ) - - # Give fallback reward if response is in target language - if detected_lang == self.target_language: - if should_debug: - print( - f" → Reward: {self.fallback_reward} (fallback, correct language)" - ) - return self.fallback_reward - - if should_debug: - print(f" → Reward: {self.no_match_reward} (wrong language)") return self.no_match_reward - # Concatenate all thinking blocks for language detection - thinking_content = " ".join(matches) + # Extract all thinking blocks + matches = self._THINK_BLOCK_RE.findall(response) + + # Determine what text to analyze + if len(matches) == 1: + # Single block: detect language of block content only + text_to_analyze = matches[0].strip() + detection_mode = "single block" + else: + # No blocks or multiple blocks: detect language of whole response + text_to_analyze = response.strip() + detection_mode = f"{len(matches)} blocks, using whole response" # Remove extra whitespace - thinking_content = re.sub(r"\s+", " ", thinking_content).strip() + text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip() - if not thinking_content: + if not text_to_analyze: if should_debug: - print( - f"\n[LanguageReward] Empty thinking blocks | Reward: {self.no_match_reward}" - ) + print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}") return self.no_match_reward # Detect language using langid - detected_lang, confidence = self._langid.classify(thinking_content) + detected_lang, confidence = self._langid.classify(text_to_analyze) + + # Check if language matches target + reward = ( + self.match_reward + if detected_lang == self.target_language + else self.no_match_reward + ) if should_debug: - sample = thinking_content[:150].replace("\n", " ") + sample = text_to_analyze[:150].replace("\n", " ") + match_symbol = "✓" if detected_lang == self.target_language else "✗" print( - f"\n[LanguageReward] Found {len(matches)} thinking block(s)" + f"\n[LanguageReward] Detection mode: {detection_mode}" f"\n Target: {self.target_language} | Detected: {detected_lang} | " f"Confidence: {confidence:.2f}" - f"\n Thinking sample: {sample}..." + f"\n Sample: {sample}..." + f"\n → Reward: {reward} {match_symbol}" ) - # Check if language matches target - if detected_lang == self.target_language: - # Full reward for correct format (single block) - if len(matches) == 1: - if should_debug: - print( - f" → Reward: {self.full_reward} (single block, correct language) ✓" - ) - return self.full_reward - # Partial reward for wrong format (multiple blocks) but correct language - else: - if should_debug: - print( - f" → Reward: {self.partial_reward} (multiple blocks, correct language)" - ) - return self.partial_reward - - if should_debug: - print(f" → Reward: {self.no_match_reward} (wrong language) ✗") - return self.no_match_reward + return reward diff --git a/tests/unit_tests/rl/test_language_reward.py b/tests/unit_tests/rl/test_language_reward.py index 1df1eb460..423ba4829 100644 --- a/tests/unit_tests/rl/test_language_reward.py +++ b/tests/unit_tests/rl/test_language_reward.py @@ -20,9 +20,7 @@ def setUp(self): self.reward_ja = LanguageReward(target_language="ja") self.custom_reward = LanguageReward( target_language="ja", - full_reward=0.9, - partial_reward=0.6, - fallback_reward=0.3, + match_reward=0.9, no_match_reward=0.1, ) @@ -30,24 +28,18 @@ def test_init_default_values(self): """Test LanguageReward initialization with default values.""" reward = self.LanguageReward() self.assertEqual(reward.target_language, "en") - self.assertEqual(reward.full_reward, 1.0) - self.assertEqual(reward.partial_reward, 0.5) - self.assertEqual(reward.fallback_reward, 0.2) + self.assertEqual(reward.match_reward, 1.0) self.assertEqual(reward.no_match_reward, 0.0) def test_init_custom_values(self): """Test LanguageReward initialization with custom values.""" reward = self.LanguageReward( target_language="ja", - full_reward=0.9, - partial_reward=0.6, - fallback_reward=0.3, + match_reward=0.9, no_match_reward=0.1, ) self.assertEqual(reward.target_language, "ja") - self.assertEqual(reward.full_reward, 0.9) - self.assertEqual(reward.partial_reward, 0.6) - self.assertEqual(reward.fallback_reward, 0.3) + self.assertEqual(reward.match_reward, 0.9) self.assertEqual(reward.no_match_reward, 0.1) def test_init_missing_langid(self): @@ -130,13 +122,13 @@ def test_call_with_no_thinking_tags(self): result = self.reward_en( "prompt", "This is just a regular response without any thinking tags." ) - # No thinking blocks but response is in English, should get fallback reward - self.assertEqual(result, 0.2) + # No thinking blocks -> detect whole response, English detected -> match_reward + self.assertEqual(result, 1.0) def test_call_with_no_thinking_tags_wrong_language(self): """Test __call__ with response containing no thinking tags and wrong language.""" result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。") - # No thinking blocks and wrong language, should get no_match_reward + # No thinking blocks -> detect whole response, Japanese detected -> no_match_reward self.assertEqual(result, 0.0) def test_call_with_empty_thinking_block(self): @@ -167,26 +159,26 @@ def test_call_with_whitespace_in_tags(self): self.assertEqual(result, 1.0) def test_call_multiple_thinking_blocks(self): - """Test __call__ with multiple thinking blocks (wrong format but correct language).""" + """Test __call__ with multiple thinking blocks - detects whole response language.""" response = """ <思考>First thought in English. Some text in between. <思考>Second thought also in English. """ result = self.reward_en("prompt", response) - # Multiple blocks = wrong format, but language is correct, should return partial_reward - self.assertEqual(result, 0.5) + # Multiple blocks -> detect whole response, English detected -> match_reward + self.assertEqual(result, 1.0) def test_call_multiple_thinking_blocks_mixed_languages(self): - """Test __call__ with multiple thinking blocks in different languages (wrong format).""" + """Test __call__ with multiple thinking blocks in different languages.""" response = """ <思考>First thought in English with lots of content here. <思考>これは短い日本語。 """ result = self.reward_en("prompt", response) - # Multiple blocks with mixed languages - langid will detect dominant language - # Should return either partial_reward (if detects English) or no_match_reward (if detects Japanese) - self.assertIn(result, [0.0, 0.5]) + # Multiple blocks -> detect whole response, langid will detect dominant language + # Should return match_reward (1.0) if English dominant, or no_match_reward (0.0) if not + self.assertIn(result, [0.0, 1.0]) def test_call_multiline_thinking_block(self): """Test __call__ with multiline thinking blocks.""" @@ -215,13 +207,13 @@ def test_call_with_target_parameter(self): result = self.reward_en("prompt", response, target="some target") self.assertEqual(result, 1.0) - # Longer English text without tags should get fallback reward + # English text without tags -> detect whole response -> match_reward result = self.reward_en( "prompt", "This is a response without thinking tags but in English language.", target="some target", ) - self.assertEqual(result, 0.2) + self.assertEqual(result, 1.0) def test_call_custom_reward_values(self): """Test __call__ with custom reward values.""" @@ -231,12 +223,12 @@ def test_call_custom_reward_values(self): response_en = "<思考>This is English." response_none = "" - # Test custom full reward (single block, correct language) + # Test custom match reward (single block, correct language) self.assertEqual(self.custom_reward("prompt", response_ja_single), 0.9) - # Test custom partial reward (multiple blocks, correct language) - self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.6) - # Test custom fallback reward (no blocks, correct language) - self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.3) + # Test custom match reward (multiple blocks -> whole response, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.9) + # Test custom match reward (no blocks -> whole response, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.9) # Test custom no_match reward (wrong language) self.assertEqual(self.custom_reward("prompt", response_en), 0.1) # Test empty response @@ -245,7 +237,7 @@ def test_call_custom_reward_values(self): def test_call_zero_custom_values(self): """Test __call__ with zero custom values.""" zero_reward = self.LanguageReward( - target_language="en", full_reward=0.0, no_match_reward=0.0 + target_language="en", match_reward=0.0, no_match_reward=0.0 ) result = zero_reward("prompt", "<思考>This is English.") self.assertEqual(result, 0.0) From 0ed798ccbd6410c1b30b7b8edbea32dd771b1527 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 16:52:51 -0700 Subject: [PATCH 11/24] Add langid to dev dependencies for CI Added langid to [project.optional-dependencies] dev section so CI tests can run without failing on import. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 8460b5b78..bb66d0191 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dev = [ "anyio", "pytest-asyncio", "multiprocess", + "langid", ] docs = [ "sphinx==7.2.6", From 5a3193e07474177fdd8cb96b558724a1c55a49b6 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 16:57:46 -0700 Subject: [PATCH 12/24] Remove debug script Removed sandbox/grpo_language/debug_reward.py and updated TROUBLESHOOTING.md to remove references to it. --- sandbox/grpo_language/TROUBLESHOOTING.md | 15 +--- sandbox/grpo_language/debug_reward.py | 89 ------------------------ 2 files changed, 1 insertion(+), 103 deletions(-) delete mode 100644 sandbox/grpo_language/debug_reward.py diff --git a/sandbox/grpo_language/TROUBLESHOOTING.md b/sandbox/grpo_language/TROUBLESHOOTING.md index 4f077a2ae..db69628b2 100644 --- a/sandbox/grpo_language/TROUBLESHOOTING.md +++ b/sandbox/grpo_language/TROUBLESHOOTING.md @@ -107,20 +107,7 @@ Check these metrics in Weights & Biases: - `reward/evaluate_response/avg_MathReward_reward` - should stay reasonably high - `reward/evaluate_response/avg_ThinkingReward_reward` - should increase quickly -### 5. Quick Debug Test - -Run the debug script to verify the reward function works: -```bash -python sandbox/grpo_language/debug_reward.py -``` - -Expected output: -- Japanese text → reward 1.0 -- English text → reward 0.0 -- Multiple Japanese blocks → reward 0.5 -- No blocks but Japanese response → reward 0.2 - -### 6. Alternative: Start with English, then transition +### 5. Alternative: Start with English, then transition If Japanese isn't working, you could: diff --git a/sandbox/grpo_language/debug_reward.py b/sandbox/grpo_language/debug_reward.py deleted file mode 100644 index 66e44342b..000000000 --- a/sandbox/grpo_language/debug_reward.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Debug script to test LanguageReward behavior.""" - -from forge.data.rewards import LanguageReward - -# Create reward for Japanese -reward = LanguageReward(target_language="ja") - -# Test cases mimicking what the model might generate -test_cases = [ - # Case 1: Perfect - Japanese in single thinking block - ("<思考>これは数学の問題です。2+2=4です。4", "Perfect Japanese"), - # Case 2: English thinking (most likely during training) - ( - "<思考>This is a math problem. 2+2=4.4", - "English thinking", - ), - # Case 3: No thinking blocks at all - ("The answer is 4.4", "No thinking blocks"), - # Case 4: Empty thinking blocks - ("<思考>4", "Empty thinking block"), - # Case 5: Multiple thinking blocks in Japanese - ( - "<思考>最初の考え。<思考>次の考え。4", - "Multiple Japanese blocks", - ), - # Case 6: Just the answer, no thinking - ("4", "Just answer tag"), - # Case 7: Thinking with mostly numbers/symbols - ("<思考>2 + 2 = 44", "Mostly numbers"), - # Case 8: Mixed English and Japanese - ("<思考>Let me think... これは簡単です。4", "Mixed languages"), -] - -print("=" * 80) -print("LanguageReward Debug Output (target_language='ja')") -print("=" * 80) - -for response, description in test_cases: - score = reward(prompt="", response=response, target=None) - - import re - - # Try to detect what langid thinks - import langid - - # Extract thinking content if exists - think_match = re.findall(r"<\s*思考\s*>(.*?)<\s*/\s*思考\s*>", response, re.DOTALL) - - if think_match: - content = " ".join(think_match) - detected_lang, confidence = langid.classify(content) - print(f"\n{description}:") - print(f" Response: {response[:60]}...") - print(f" Reward: {score}") - print(f" Thinking blocks found: {len(think_match)}") - print(f" Detected language: {detected_lang} (confidence: {confidence:.3f})") - else: - # Check fallback - response_text = re.sub(r"<\s*/?\s*思考\s*>", "", response).strip() - if response_text: - detected_lang, confidence = langid.classify(response_text) - print(f"\n{description}:") - print(f" Response: {response[:60]}...") - print(f" Reward: {score}") - print(" Thinking blocks found: 0") - print( - f" Fallback detection on response text: {detected_lang} (confidence: {confidence:.3f})" - ) - else: - print(f"\n{description}:") - print(f" Response: {response[:60]}...") - print(f" Reward: {score}") - print(" No content to analyze") - -print("\n" + "=" * 80) -print("Expected rewards (simplified logic):") -print(" match_reward (1.0): Detected language matches target (ja)") -print(" no_match_reward (0.0): Detected language doesn't match target") -print("\nDetection strategy:") -print(" - Single thinking block: detect language of block content only") -print(" - Multiple blocks or no blocks: detect language of whole response") -print("=" * 80) From 93a65b2894808ac262071bb7198ea3dde1cd94d2 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 17:00:04 -0700 Subject: [PATCH 13/24] Clarify why English training won't work in TROUBLESHOOTING Replaced misleading "Start with English" section with explanation of why English training won't work (already extensively pre-trained) and why Japanese is the right choice (novel combination, clear RL signal). --- sandbox/grpo_language/TROUBLESHOOTING.md | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/sandbox/grpo_language/TROUBLESHOOTING.md b/sandbox/grpo_language/TROUBLESHOOTING.md index db69628b2..a8765ce97 100644 --- a/sandbox/grpo_language/TROUBLESHOOTING.md +++ b/sandbox/grpo_language/TROUBLESHOOTING.md @@ -107,21 +107,16 @@ Check these metrics in Weights & Biases: - `reward/evaluate_response/avg_MathReward_reward` - should stay reasonably high - `reward/evaluate_response/avg_ThinkingReward_reward` - should increase quickly -### 5. Alternative: Start with English, then transition +### 5. Why Not Train with English? -If Japanese isn't working, you could: +Training with English thinking won't work well because: +- Models are already extensively trained on GSM8K and similar datasets with English thinking +- There's little room for improvement on English math reasoning +- The RL signal would be weak (model already knows how to do this) -1. Train first with English to get good math performance -2. Then fine-tune with Japanese language reward +**That's why we use Japanese** - it provides a novel combination of math reasoning + non-English thinking that the model hasn't been extensively pre-trained on, giving clear RL signal for improvement. -Change line 327 to: -```python -LanguageReward(target_language="en") # Start with English -``` - -Once math rewards are good, switch to `"ja"` and continue training. - -### 7. Nuclear Option: Much Stronger Prompt +### 6. Nuclear Option: Much Stronger Prompt If nothing else works, try this very explicit prompt: ```python From f72be7f259147c086c3867d24a9e9c16d96771cf Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 17:10:35 -0700 Subject: [PATCH 14/24] Add unit test for ThinkingReward custom tag Added test_custom_tag to verify that ThinkingReward correctly uses the custom tag passed in during initialization. The test confirms: - Responses with custom tag get full reward - Responses with default tag get no reward when custom tag is used All 26 ThinkingReward tests passing. --- tests/unit_tests/rl/test_thinking_reward.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py index b95823e9a..10b7bf38e 100644 --- a/tests/unit_tests/rl/test_thinking_reward.py +++ b/tests/unit_tests/rl/test_thinking_reward.py @@ -203,6 +203,19 @@ def test_call_very_long_thinking_block(self): result = self.reward("prompt", f"{long_content}") self.assertEqual(result, 1.0) + def test_custom_tag(self): + """Test that ThinkingReward uses the custom tag passed in.""" + # Create reward with custom Japanese tag + custom_tag_reward = ThinkingReward(tag="思考") + + # Response with custom tag should get full reward + result = custom_tag_reward("prompt", "<思考>This is my reasoning") + self.assertEqual(result, 1.0) + + # Response with default "think" tag should get no reward + result = custom_tag_reward("prompt", "This is my reasoning") + self.assertEqual(result, 0.0) + if __name__ == "__main__": unittest.main() From 6186f9f310da1d37933f76f2d1eef348284888ae Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 17:20:04 -0700 Subject: [PATCH 15/24] Bump LanguageReward match_reward to 2.0 Increased match_reward from default 1.0 to 2.0 to give more weight to language matching in the multi-objective reward. --- sandbox/grpo_language/main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index c27c55638..9fe902233 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -329,7 +329,11 @@ async def main(cfg: DictConfig): MathReward(), ThinkingReward(tag="思考"), # Use Japanese tag LanguageReward( - target_language="ja", tag="思考", debug=True, debug_sample_rate=0.1 + target_language="ja", + tag="思考", + match_reward=2.0, + debug=True, + debug_sample_rate=0.1, ), # Japanese language reward with debug ] ), From c640d379c5fd30a3b43a37623132ee020017905b Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 18:52:04 -0700 Subject: [PATCH 16/24] Set KL divergence coefficient to zero in loss function Changed beta parameter from 0.1 to 0.0 in simple_grpo_loss to remove the KL divergence penalty term from the loss. --- sandbox/grpo_language/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 9fe902233..597fc31e4 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -125,7 +125,7 @@ def simple_grpo_loss( ref_logprobs: torch.Tensor, advantages: torch.Tensor, padding_mask: torch.Tensor, - beta: float = 0.1, + beta: float = 0.0, ) -> torch.Tensor: logprobs: torch.Tensor = compute_logprobs(logits, response) kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 From 7fde86dcf494cf43883484651b7d3931e9ef13b9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 20:50:27 -0700 Subject: [PATCH 17/24] Change KL divergence coefficient to 1e-3 Changed beta parameter from 0.0 to 1e-3 in simple_grpo_loss to add a small KL divergence penalty. --- sandbox/grpo_language/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 597fc31e4..d5d5ded64 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -125,7 +125,7 @@ def simple_grpo_loss( ref_logprobs: torch.Tensor, advantages: torch.Tensor, padding_mask: torch.Tensor, - beta: float = 0.0, + beta: float = 1e-3, ) -> torch.Tensor: logprobs: torch.Tensor = compute_logprobs(logits, response) kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 From ffb6c432f3bf149f17fa596c6dfe2fb697d5ac5b Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 31 Oct 2025 23:49:17 -0700 Subject: [PATCH 18/24] Change KL divergence coefficient to 1e-4 Changed beta parameter from 1e-3 to 1e-4 in simple_grpo_loss. --- sandbox/grpo_language/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index d5d5ded64..5b89e8e7f 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -125,7 +125,7 @@ def simple_grpo_loss( ref_logprobs: torch.Tensor, advantages: torch.Tensor, padding_mask: torch.Tensor, - beta: float = 1e-3, + beta: float = 1e-4, ) -> torch.Tensor: logprobs: torch.Tensor = compute_logprobs(logits, response) kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 From 7ffa20e070d0c76c4773798ddf0e9cb3085faca0 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 1 Nov 2025 21:44:23 -0700 Subject: [PATCH 19/24] Enable multi-epoch training in sandbox/grpo_language app Apply the multi-epoch training fix from fix-multi-epoch-training branch: - Add epoch tracking to DatasetActor - Store base dataset for iterator reuse - Restart iterator with set_epoch() when StopIteration occurs - Add epoch completion logging and metrics This allows training to continue beyond the first epoch instead of stopping when the dataset is exhausted. --- sandbox/grpo_language/main.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 5b89e8e7f..a86fd7b97 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -213,6 +213,7 @@ class DatasetActor(ForgeActor): @endpoint def setup(self): self._tokenizer = get_tokenizer(self.model) + self._epoch = 0 def gsm8k_transform(sample): system_prompt = """ @@ -239,12 +240,12 @@ def gsm8k_transform(sample): formatted_target = target.split("#### ")[1] return {"request": formatted_request, "target": formatted_target} - ds = load_dataset( + self._base_dataset = load_dataset( self.path, self.revision, split=self.data_split, streaming=self.streaming ) - ds = ds.map(gsm8k_transform) - ds = ds.shuffle() - self._iterator = iter(ds) + self._base_dataset = self._base_dataset.map(gsm8k_transform) + self._base_dataset = self._base_dataset.shuffle() + self._iterator = iter(self._base_dataset) @endpoint async def sample(self) -> dict[str, str] | None: @@ -260,7 +261,15 @@ async def sample(self) -> dict[str, str] | None: return sample except StopIteration: - return None + # Restart iterator for next epoch + self._epoch += 1 + print( + f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" + ) + record_metric("dataset/sample/epoch_completed", self._epoch, Reduce.LAST) + self._base_dataset.set_epoch(self._epoch) + self._iterator = iter(self._base_dataset) + return await self.sample() @endpoint async def pad_token(self): From 1bf3cca6996511630bd6d81bf5ec62f1cb4904db Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 1 Nov 2025 21:49:13 -0700 Subject: [PATCH 20/24] Fix recursive endpoint call - use while loop instead Endpoints cannot be called recursively. Changed the StopIteration handler from `return await self.sample()` to using a while loop that continues to retry getting the next sample after restarting the iterator. --- sandbox/grpo_language/main.py | 43 +++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index a86fd7b97..147bff7d5 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -249,27 +249,30 @@ def gsm8k_transform(sample): @endpoint async def sample(self) -> dict[str, str] | None: - try: - sample = next(self._iterator) - - record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) - record_metric( - "dataset/sample/avg_sample_len", - len(sample["request"]), - Reduce.MEAN, - ) + while True: + try: + sample = next(self._iterator) + + record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + record_metric( + "dataset/sample/avg_sample_len", + len(sample["request"]), + Reduce.MEAN, + ) - return sample - except StopIteration: - # Restart iterator for next epoch - self._epoch += 1 - print( - f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" - ) - record_metric("dataset/sample/epoch_completed", self._epoch, Reduce.LAST) - self._base_dataset.set_epoch(self._epoch) - self._iterator = iter(self._base_dataset) - return await self.sample() + return sample + except StopIteration: + # Restart iterator for next epoch + self._epoch += 1 + print( + f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" + ) + record_metric( + "dataset/sample/epoch_completed", self._epoch, Reduce.LAST + ) + self._base_dataset.set_epoch(self._epoch) + self._iterator = iter(self._base_dataset) + # Continue loop to get next sample from restarted iterator @endpoint async def pad_token(self): From 7758b4807fa228194554e156eacd3fb8ae92eaa8 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 1 Nov 2025 21:54:11 -0700 Subject: [PATCH 21/24] Simplify multi-epoch fix - use return next() instead of while loop Instead of while True loop, simply return next(self._iterator) after restarting the iterator. This is cleaner and safer - if the dataset is truly empty, it will raise StopIteration without infinite recursion. --- sandbox/grpo_language/main.py | 43 ++++++++++++++++------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index 147bff7d5..db39a0cfd 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -249,30 +249,27 @@ def gsm8k_transform(sample): @endpoint async def sample(self) -> dict[str, str] | None: - while True: - try: - sample = next(self._iterator) - - record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) - record_metric( - "dataset/sample/avg_sample_len", - len(sample["request"]), - Reduce.MEAN, - ) + try: + sample = next(self._iterator) - return sample - except StopIteration: - # Restart iterator for next epoch - self._epoch += 1 - print( - f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" - ) - record_metric( - "dataset/sample/epoch_completed", self._epoch, Reduce.LAST - ) - self._base_dataset.set_epoch(self._epoch) - self._iterator = iter(self._base_dataset) - # Continue loop to get next sample from restarted iterator + record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + record_metric( + "dataset/sample/avg_sample_len", + len(sample["request"]), + Reduce.MEAN, + ) + + return sample + except StopIteration: + # Restart iterator for next epoch + self._epoch += 1 + print( + f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" + ) + record_metric("dataset/sample/epoch_completed", self._epoch, Reduce.LAST) + self._base_dataset.set_epoch(self._epoch) + self._iterator = iter(self._base_dataset) + return next(self._iterator) @endpoint async def pad_token(self): From f71dbb645bb36d3473b7d68c3fc7c9753c353c76 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 3 Nov 2025 10:28:21 -0800 Subject: [PATCH 22/24] fix --- sandbox/grpo_language/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index db39a0cfd..b3bff910e 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -266,7 +266,7 @@ async def sample(self) -> dict[str, str] | None: print( f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" ) - record_metric("dataset/sample/epoch_completed", self._epoch, Reduce.LAST) + record_metric("dataset/sample/epoch_completed", self._epoch, Reduce.MAX) self._base_dataset.set_epoch(self._epoch) self._iterator = iter(self._base_dataset) return next(self._iterator) From 735af9ac7dd48d5d8ccef7fa44af743cf41578b8 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 3 Nov 2025 10:29:15 -0800 Subject: [PATCH 23/24] change logging --- sandbox/grpo_language/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/grpo_language/main.py b/sandbox/grpo_language/main.py index b3bff910e..6059dc112 100644 --- a/sandbox/grpo_language/main.py +++ b/sandbox/grpo_language/main.py @@ -259,6 +259,7 @@ async def sample(self) -> dict[str, str] | None: Reduce.MEAN, ) + record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX) return sample except StopIteration: # Restart iterator for next epoch @@ -266,7 +267,6 @@ async def sample(self) -> dict[str, str] | None: print( f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" ) - record_metric("dataset/sample/epoch_completed", self._epoch, Reduce.MAX) self._base_dataset.set_epoch(self._epoch) self._iterator = iter(self._base_dataset) return next(self._iterator) From ef39e46d6c49b53566c4240fa44198514525bf73 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 3 Nov 2025 10:29:55 -0800 Subject: [PATCH 24/24] git mv --- {sandbox => tests/sandbox}/grpo_language/README.md | 0 {sandbox => tests/sandbox}/grpo_language/TROUBLESHOOTING.md | 0 {sandbox => tests/sandbox}/grpo_language/main.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {sandbox => tests/sandbox}/grpo_language/README.md (100%) rename {sandbox => tests/sandbox}/grpo_language/TROUBLESHOOTING.md (100%) rename {sandbox => tests/sandbox}/grpo_language/main.py (100%) diff --git a/sandbox/grpo_language/README.md b/tests/sandbox/grpo_language/README.md similarity index 100% rename from sandbox/grpo_language/README.md rename to tests/sandbox/grpo_language/README.md diff --git a/sandbox/grpo_language/TROUBLESHOOTING.md b/tests/sandbox/grpo_language/TROUBLESHOOTING.md similarity index 100% rename from sandbox/grpo_language/TROUBLESHOOTING.md rename to tests/sandbox/grpo_language/TROUBLESHOOTING.md diff --git a/sandbox/grpo_language/main.py b/tests/sandbox/grpo_language/main.py similarity index 100% rename from sandbox/grpo_language/main.py rename to tests/sandbox/grpo_language/main.py