From 7e8e8f2a3cd4f3a5a05e360bdc47c7030a437789 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 5 Nov 2025 04:28:19 -0800 Subject: [PATCH 1/2] [WIP] On Policy Distillation (sync) --- apps/on-policy-distillation/main.py | 216 ++++++++++++++++++ .../qwen_0_6b_to_8b.yaml | 111 +++++++++ src/forge/actors/reference_model.py | 80 ++++++- src/forge/actors/trainer.py | 27 ++- 4 files changed, 412 insertions(+), 22 deletions(-) create mode 100644 apps/on-policy-distillation/main.py create mode 100644 apps/on-policy-distillation/qwen_0_6b_to_8b.yaml diff --git a/apps/on-policy-distillation/main.py b/apps/on-policy-distillation/main.py new file mode 100644 index 000000000..ce5bef000 --- /dev/null +++ b/apps/on-policy-distillation/main.py @@ -0,0 +1,216 @@ +import asyncio +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn.functional as F +import torchstore as ts +from datasets import load_dataset +from forge.actors.generator import Generator +from forge.actors.reference_model import ReferenceModel +from forge.actors.trainer import RLTrainer +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data_models.completion import Completion +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.util.config import parse +from forge.util.ops import compute_logprobs +from omegaconf import DictConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + + +@dataclass +class Trajectory: + pad_id: int + request_len: int + response_len: int + # Processed data + completion: Completion | None = None + teacher_logprobs: torch.Tensor | None = None + + @property + def request_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + elif tensor.shape[0] > self.request_len: # truncate + tensor = tensor[-self.request_len :] + return tensor + + @property + def response_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.token_ids.to(torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + elif tensor.shape[0] > self.response_len: # truncate + tensor = tensor[: self.response_len] + return tensor + + +def collate( + batches: list[list[Trajectory]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + inputs = [] + targets = [] + for batch in batches: + request = [t.request_tensor for t in batch] + request = torch.stack(request) + + response = [t.response_tensor for t in batch] + response = torch.stack(response) + + teacher_logprobs = [t.teacher_logprobs for t in batch] + teacher_logprobs = torch.stack(teacher_logprobs) + + pad_id = batch[0].pad_id + padding_mask = response != pad_id + + input = {"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "teacher_logprobs": teacher_logprobs, + "padding_mask": padding_mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +def importance_sampling_loss( + logits: torch.Tensor, + response: torch.Tensor, + teacher_logprobs: torch.Tensor, + padding_mask: torch.Tensor, + **kwargs, +) -> torch.Tensor: + student_logprobs = compute_logprobs(logits, response) + reverse_kl = -(student_logprobs - teacher_logprobs) + prob_ratio = torch.exp(teacher_logprobs - student_logprobs) + per_token_loss = 
prob_ratio * reverse_kl + + # Apply mask and compute mean over valid tokens + masked_loss = per_token_loss * padding_mask + num_valid_tokens = padding_mask.sum(dim=1, keepdim=True).clamp(min=1.0) + loss_per_sequence = masked_loss.sum(dim=1, keepdim=True) / num_valid_tokens + loss = loss_per_sequence.mean() + + return loss + + +async def main(cfg: DictConfig): + train_batch_size = cfg.train_batch_size + max_steps = cfg.trainer.training.steps + max_req_tokens = cfg.max_req_tokens + max_res_tokens = cfg.max_res_tokens + + provisioner = await init_provisioner() + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one( + { + "wandb": {"project": "opd-v0", "logging_mode": "global_reduce"}, + "console": {"logging_mode": "global_reduce"}, + } + ) + student_trainer, student_generator, teacher = await asyncio.gather( + RLTrainer.options(**cfg.services.trainer).as_actor( + **cfg.trainer, loss=importance_sampling_loss + ), + Generator.options(**cfg.services.student_generator).as_service( + **cfg.student_generator + ), + ReferenceModel.options(**cfg.services.teacher).as_service(**cfg.teacher), + ) + + # Setup torchstore for weight management + trainer_num_procs = cfg.services.trainer["procs"] + trainer_host_mesh_name = cfg.services.trainer["mesh_name"] + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), + strategy=ts.LocalRankStrategy(), + ) + + # Load dataset + tokenizer = get_tokenizer(cfg.student_model) + pad_id = tokenizer.pad_token_id + dataset = load_dataset(cfg.dataset.path, split=cfg.dataset.get("split", "train")) + dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"]) + dataset_iter = iter(dataset) + + print("All services initialized successfully!") + + step = 0 + for epoch in range(max_steps): + if step >= max_steps: + break + + # Collect rollout + trajectories = [] + while len(trajectories) < train_batch_size: + try: + sample = next(dataset_iter) + # Extract the human prompt from OpenThoughts format + conversations = sample.get("conversations", []) + if conversations and len(conversations) > 0: + prompt = conversations[0].get("value", "") + else: + prompt = sample.get("prompt", sample.get("text", "")) + + print(f"Starting request with prompt: {prompt}") + completions = await student_generator.generate.route(prompt) + + for completion in completions: + # Create trajectory with raw completion + trajectory = Trajectory( + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + completion=completion, + ) + + # Build padded input for teacher using trajectory properties + input_ids = torch.cat( + [ + trajectory.request_tensor.unsqueeze(0), + trajectory.response_tensor.unsqueeze(0), + ], + dim=1, + ) + + teacher_logprobs = await teacher.forward.route( + input_ids, max_req_tokens, return_logprobs=True + ) + + trajectory.teacher_logprobs = teacher_logprobs + trajectories.append(trajectory) + except StopIteration: + print("Dataset exhausted, resetting iterator") + dataset_iter = iter(dataset) + + # Train on collected trajectories + trajectories = [ + trajectories[i::train_batch_size] for i in range(train_batch_size) + ] + inputs, targets = collate(trajectories) + await student_trainer.train_step.call(inputs, targets) + + step += 1 + + # Push weights to student generator + await student_trainer.push_weights.call(step) + await student_generator.update_weights.fanout(step) + + await 
mlogger.flush.call_one(step) + + print(f"Training completed after {step} steps") + await shutdown() + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() diff --git a/apps/on-policy-distillation/qwen_0_6b_to_8b.yaml b/apps/on-policy-distillation/qwen_0_6b_to_8b.yaml new file mode 100644 index 000000000..2df9c5196 --- /dev/null +++ b/apps/on-policy-distillation/qwen_0_6b_to_8b.yaml @@ -0,0 +1,111 @@ +# On-Policy Distillation: Qwen 0.6B (student) learning from Qwen 8B (teacher) +# >>> python -m apps.on-policy-distillation.main --config apps/on-policy-distillation/qwen_0_6b_to_8b.yaml + +# Global configuration +train_batch_size: 4 # Number of trajectories per training step +max_req_tokens: 512 +max_res_tokens: 65536 +student_model: "Qwen/Qwen3-1.7B" +teacher_model: "Qwen/Qwen3-8B" + +# Dataset configuration +dataset: + path: "open-thoughts/OpenThoughts3-1.2M" + split: "train" + domain: "math" + +# Student Generator configuration (inference model) +student_generator: + engine_args: + model: ${student_model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_params: + n: 2 # Single response per prompt + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 0.95 + +# Student Trainer configuration +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${student_model} + optimizer: + name: AdamW + lr: 5e-5 # Higher LR for distillation + eps: 1e-8 + lr_scheduler: + warmup_steps: 10 + training: + local_batch_size: ${train_batch_size} # Per-device batch size + seq_len: 66048 + max_norm: 1.0 + steps: 10000 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 2 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + # folder: ./checkpoint_student + initial_load_path: hf://${student_model} + initial_load_in_hf: true + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Teacher model configuration +teacher: + model: + name: qwen3 + flavor: 8B + hf_assets_path: hf://${teacher_model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 2 + tensor_parallel_degree: 1 # Use 2 GPUs for teacher + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: hf://${teacher_model} + initial_load_in_hf: true + +# Resource allocations (3 GPUs total) +services: + student_generator: + procs: 1 # Student inference: 1 GPU + num_replicas: 1 + mesh_name: student_generator + with_gpus: true + teacher: + procs: 2 # Teacher: 2 GPUs with TP + num_replicas: 1 + mesh_name: teacher + with_gpus: true + trainer: + procs: 2 # Student training: shares GPU with student_generator + num_replicas: 1 + mesh_name: trainer + with_gpus: true diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index 02a6e1410..8684c307e 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -13,6 +13,13 @@ from dataclasses import dataclass, field, fields import torch +import torch.nn.functional as F + +from forge.controller import ForgeActor +from forge.observability.metrics import 
record_metric, Reduce +from forge.observability.perf_tracker import Tracer + +# from forge.util.ops import compute_logprobs from monarch.actor import current_rank, current_size, endpoint from torch.distributed.tensor import DTensor @@ -27,11 +34,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.controller import ForgeActor -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer -from forge.util.ops import compute_logprobs - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -180,15 +182,77 @@ async def forward( with torch.inference_mode(): logits = self.model(input_ids) self.step += 1 - if isinstance(logits, DTensor): - logits = logits.full_tensor() + # if isinstance(logits, DTensor): + # logits = logits.full_tensor() t.step("forward") if not return_logprobs: t.stop() + if isinstance(logits, DTensor): + return logits.full_tensor() return logits else: - logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:]) + logprobs = compute_logprobs_chunked(logits, input_ids[:, max_req_tokens:]) t.step("compute_logprobs") t.stop() return logprobs + + +def compute_logprobs_chunked( + logits: torch.Tensor | DTensor, + input_ids: torch.Tensor, + temperature: float = 1.0, + align: bool = True, + chunk_size: int = 512, +) -> torch.Tensor: + """ + Memory-efficient version that processes logits in chunks along the sequence dimension. + Useful for very long sequences where even the DTensor operations might cause memory issues. + + Args: + chunk_size: Number of tokens to process at once. Lower values use less memory. + """ + is_dtensor = isinstance(logits, DTensor) + + # Align logits with input_ids if requested + if align: + target_len = input_ids.size(1) + logits = logits[:, -target_len - 1 : -1, :] + if not is_dtensor: + logits = logits.to(input_ids.device) + + batch_size, seq_len, vocab_size = logits.shape + + # Initialize output tensor + logprobs = torch.zeros( + batch_size, seq_len, dtype=torch.float32, device=logits.device + ) + + # Process in chunks + for start_idx in range(0, seq_len, chunk_size): + end_idx = min(start_idx + chunk_size, seq_len) + + # Get chunk of logits and input_ids + logits_chunk = logits[:, start_idx:end_idx, :] + input_chunk = input_ids[:, start_idx:end_idx] + + # Scale and convert to fp32 + scaled_chunk = (logits_chunk / temperature).float() + + # Compute log probabilities for this chunk + chunk_size_actual = end_idx - start_idx + flat_logits = scaled_chunk.reshape(-1, vocab_size) + flat_targets = input_chunk.reshape(-1).long() + + chunk_logprobs = -F.cross_entropy( + flat_logits, + flat_targets, + reduction="none", + ) + + # Store in output tensor + logprobs[:, start_idx:end_idx] = chunk_logprobs.reshape( + batch_size, chunk_size_actual + ) + + return logprobs diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index e45ddd28e..7f05bf102 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -16,6 +16,18 @@ import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.actors._torchstore_utils import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, + rdma_available, +) + +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + from monarch.actor import endpoint from torch import 
Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -36,18 +48,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, - rdma_available, -) - -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -176,8 +176,7 @@ async def train_step( # TODO: delete item() to avoid cpu-gpu sync loss = loss.detach().item() - record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM) - record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN) + record_metric("rl_trainer/avg_loss", loss, Reduce.MEAN) # These are placeholder values until the loss function exposes these metrics # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN) From 2a46586dc950587739613118450c598e30bf3fb6 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 6 Nov 2025 09:02:51 -0800 Subject: [PATCH 2/2] Stub --- apps/on_policy_distillation/README.md | 107 ++++++++++++++++++ apps/on_policy_distillation/config.py | 7 ++ .../main.py | 28 +++-- .../qwen_0_6b_to_8b.yaml | 43 ++++--- apps/sft/main.py | 33 +++++- src/forge/actors/reference_model.py | 8 +- src/forge/actors/trainer.py | 24 ++-- src/forge/data/datasets/__init__.py | 9 +- src/forge/data/datasets/sft_dataset.py | 69 +++++++++++ 9 files changed, 273 insertions(+), 55 deletions(-) create mode 100644 apps/on_policy_distillation/README.md create mode 100644 apps/on_policy_distillation/config.py rename apps/{on-policy-distillation => on_policy_distillation}/main.py (90%) rename apps/{on-policy-distillation => on_policy_distillation}/qwen_0_6b_to_8b.yaml (75%) diff --git a/apps/on_policy_distillation/README.md b/apps/on_policy_distillation/README.md new file mode 100644 index 000000000..e251e0a0f --- /dev/null +++ b/apps/on_policy_distillation/README.md @@ -0,0 +1,107 @@ +# On-Policy Distillation for Math Reasoning + +This app implements on-policy distillation (OPD) following the approach described in the [Thinking Machines blog post](https://thinkingmachines.ai/blog/on-policy-distillation/). OPD combines the benefits of on-policy training with dense reward signals for efficient post-training. + +## Overview + +On-policy distillation trains a student model by: +1. Sampling trajectories from the student model itself +2. Using a teacher model to grade each token with dense rewards (per-token KL divergence) +3. Training the student to minimize reverse KL with the teacher + +This approach is **10-30x more compute efficient** than traditional RL while achieving comparable or better performance. 
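+
+The loss used in `apps/on_policy_distillation/main.py` (`importance_sampling_loss`) reduces to a per-token reverse KL weighted by an importance ratio. The sketch below is illustrative only and assumes per-token log-probabilities for both models have already been gathered (the in-repo version computes the student side from logits via `compute_logprobs`); the function and variable names here are for exposition, not part of the library API:
+
+```python
+import torch
+
+
+def per_token_reverse_kl_loss(
+    student_logprobs: torch.Tensor,  # [batch, seq] log p_student of the sampled tokens
+    teacher_logprobs: torch.Tensor,  # [batch, seq] log p_teacher of the same tokens
+    padding_mask: torch.Tensor,      # [batch, seq] True where the token is not padding
+) -> torch.Tensor:
+    # Reverse KL per token, estimated on trajectories sampled from the student
+    reverse_kl = -(student_logprobs - teacher_logprobs)
+    # Importance ratio p_teacher / p_student, mirroring importance_sampling_loss
+    prob_ratio = torch.exp(teacher_logprobs - student_logprobs)
+    per_token_loss = prob_ratio * reverse_kl
+
+    # Mean over valid tokens per sequence, then mean over the batch
+    mask = padding_mask.float()
+    num_valid = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
+    per_sequence = (per_token_loss * mask).sum(dim=1, keepdim=True) / num_valid
+    return per_sequence.mean()
+```
+
+Because every token carries a reward signal, this objective gives much denser supervision per sampled trajectory than an episode-level RL reward, which is where the efficiency gain above comes from.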
+ +## Experimental Setup + +### Models +- **Student**: Qwen3-0.6B-Base (or Qwen3-8B for larger experiments) +- **Teacher**: Qwen3-8B (or Qwen3-32B) +- **Evaluation**: AIME'24 benchmark + +### Training Pipeline + +#### Phase 1: Supervised Fine-Tuning (SFT) +First, establish a strong baseline through off-policy distillation: + +```bash +python -m apps.sft.main --config apps/sft/qwen3_0_6.yaml +``` + +- **Dataset**: OpenThoughts3-1.2M (400k prompts) +- **Expected Performance**: ~60% on AIME'24 +- **Purpose**: Teaches the model basic math reasoning patterns + +#### Phase 2: On-Policy Distillation +Refine the model using on-policy learning with dense supervision: + +```bash +python -m apps.on-policy-distillation.main --config apps/on-policy-distillation/qwen_opd.yaml +``` + +- **Starting Point**: SFT checkpoint from Phase 1 +- **Dataset**: Math prompts (from OpenThoughts3 or DeepMath, but only prompts - not solutions) +- **Training**: ~150 steps (77k prompts with 4 samples each) +- **Expected Performance**: ~70% on AIME'24 + +### Key Implementation Details + +1. **Loss Function**: Per-token reverse KL divergence + ```python + reverse_kl = -(student_logprobs - teacher_logprobs) + ``` + +2. **Sampling**: Generate multiple trajectories per prompt (n=16 in config) + +3. **No Discount Factor**: Optimize only immediate next token (discount=0) + +4. **Efficient Batching**: Can use smaller batch sizes than RL due to dense rewards + +## Evaluation + +Evaluate on AIME'24 benchmark after each phase: + +```bash +python -m apps.eval.aime --checkpoint +``` + +## Expected Results + +| Method | AIME'24 Score | Training Cost | +|--------|---------------|---------------| +| SFT (400k prompts) | ~60% | Baseline | +| SFT (2M prompts, extrapolated) | ~70% | 5x baseline | +| OPD (150 steps) | ~70% | 0.1-0.3x baseline | + +## Key Advantages + +- **Compute Efficiency**: 10-30x reduction vs traditional RL +- **Dense Supervision**: Learns from every token, not just final rewards +- **Data Efficiency**: Can reuse prompts multiple times effectively +- **Stability**: More stable training than sparse RL rewards + +## Notes for Reproduction + +1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD +2. **Use prompts only**: During OPD, sample completions from student, don't use dataset solutions +3. **Teacher quality matters**: Better teachers provide better supervision +4. **Monitor reverse KL**: Should decrease to near-zero as training progresses + +## References + +- [On-Policy Distillation Blog Post](https://thinkingmachines.ai/blog/on-policy-distillation/) +- [Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook) +- [OpenThoughts3 Dataset](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M) + +--- + +**Important Code Modification Needed**: Your current OPD implementation should: +1. Load from an SFT checkpoint (not raw base model) +2. Extract only prompts from the dataset (not use the solutions) +3. Add proper checkpoint loading in the trainer config: + +```yaml +trainer: + checkpoint: + initial_load_path: ./checkpoint_student/sft_final # Load SFT checkpoint + # ... 
rest of config +``` diff --git a/apps/on_policy_distillation/config.py b/apps/on_policy_distillation/config.py new file mode 100644 index 000000000..f1b3e3b47 --- /dev/null +++ b/apps/on_policy_distillation/config.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass +class DatasetConfig: + source: str + split: str = "train" diff --git a/apps/on-policy-distillation/main.py b/apps/on_policy_distillation/main.py similarity index 90% rename from apps/on-policy-distillation/main.py rename to apps/on_policy_distillation/main.py index ce5bef000..71ad7dbd2 100644 --- a/apps/on-policy-distillation/main.py +++ b/apps/on_policy_distillation/main.py @@ -1,4 +1,6 @@ import asyncio +import itertools +import time from dataclasses import dataclass from typing import Any @@ -63,6 +65,9 @@ def collate( teacher_logprobs = [t.teacher_logprobs for t in batch] teacher_logprobs = torch.stack(teacher_logprobs) + # student_logprobs = [t.completion.logprobs for t in batch] + # student_logprobs = torch.stack(student_logprobs) + pad_id = batch[0].pad_id padding_mask = response != pad_id @@ -70,6 +75,7 @@ def collate( target = { "response": response, "teacher_logprobs": teacher_logprobs, + # "student_logprobs": student_logprobs, "padding_mask": padding_mask, } inputs.append(input) @@ -81,6 +87,7 @@ def importance_sampling_loss( logits: torch.Tensor, response: torch.Tensor, teacher_logprobs: torch.Tensor, + # student_logprobs: torch.Tensor, padding_mask: torch.Tensor, **kwargs, ) -> torch.Tensor: @@ -135,32 +142,28 @@ async def main(cfg: DictConfig): tokenizer = get_tokenizer(cfg.student_model) pad_id = tokenizer.pad_token_id dataset = load_dataset(cfg.dataset.path, split=cfg.dataset.get("split", "train")) - dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"]) + # dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"]) dataset_iter = iter(dataset) print("All services initialized successfully!") step = 0 for epoch in range(max_steps): + # start time + start = time.perf_counter() if step >= max_steps: break - # Collect rollout trajectories = [] while len(trajectories) < train_batch_size: try: sample = next(dataset_iter) - # Extract the human prompt from OpenThoughts format - conversations = sample.get("conversations", []) - if conversations and len(conversations) > 0: - prompt = conversations[0].get("value", "") - else: - prompt = sample.get("prompt", sample.get("text", "")) + conversation = sample["conversations"] + prompt = conversation[0]["value"] - print(f"Starting request with prompt: {prompt}") - completions = await student_generator.generate.route(prompt) + completions = await student_generator.generate.fanout(prompt) - for completion in completions: + for completion in itertools.chain(*completions): # Create trajectory with raw completion trajectory = Trajectory( pad_id=pad_id, @@ -201,6 +204,9 @@ async def main(cfg: DictConfig): await student_trainer.push_weights.call(step) await student_generator.update_weights.fanout(step) + end = time.perf_counter() + print(f"Step {step} took {end - start} seconds") + await mlogger.flush.call_one(step) print(f"Training completed after {step} steps") diff --git a/apps/on-policy-distillation/qwen_0_6b_to_8b.yaml b/apps/on_policy_distillation/qwen_0_6b_to_8b.yaml similarity index 75% rename from apps/on-policy-distillation/qwen_0_6b_to_8b.yaml rename to apps/on_policy_distillation/qwen_0_6b_to_8b.yaml index 2df9c5196..f75a3091c 100644 --- a/apps/on-policy-distillation/qwen_0_6b_to_8b.yaml +++ 
b/apps/on_policy_distillation/qwen_0_6b_to_8b.yaml @@ -2,17 +2,16 @@ # >>> python -m apps.on-policy-distillation.main --config apps/on-policy-distillation/qwen_0_6b_to_8b.yaml # Global configuration -train_batch_size: 4 # Number of trajectories per training step -max_req_tokens: 512 -max_res_tokens: 65536 -student_model: "Qwen/Qwen3-1.7B" +train_batch_size: 64 # Number of trajectories per training step +max_req_tokens: 2048 +max_res_tokens: 4096 +student_model: "Qwen/Qwen3-0.6B" teacher_model: "Qwen/Qwen3-8B" # Dataset configuration dataset: path: "open-thoughts/OpenThoughts3-1.2M" split: "train" - domain: "math" # Student Generator configuration (inference model) student_generator: @@ -22,7 +21,7 @@ student_generator: pipeline_parallel_size: 1 enforce_eager: false sampling_params: - n: 2 # Single response per prompt + n: 16 max_tokens: ${max_res_tokens} temperature: 1.0 top_p: 0.95 @@ -31,7 +30,7 @@ student_generator: trainer: model: name: qwen3 - flavor: 1.7B + flavor: 0.6B hf_assets_path: hf://${student_model} optimizer: name: AdamW @@ -41,16 +40,16 @@ trainer: warmup_steps: 10 training: local_batch_size: ${train_batch_size} # Per-device batch size - seq_len: 66048 + seq_len: 8192 max_norm: 1.0 steps: 10000 dtype: bfloat16 - gc_freq: 1 + gc_freq: 5 compile: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 2 + data_parallel_shard_degree: 1 tensor_parallel_degree: 1 pipeline_parallel_degree: 1 context_parallel_degree: 1 @@ -58,15 +57,15 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - # folder: ./checkpoint_student + folder: ./checkpoint_student initial_load_path: hf://${student_model} initial_load_in_hf: true last_save_in_hf: true - interval: 500 + interval: 250 async_mode: "disabled" activation_checkpoint: - mode: selective - selective_ac_option: op + mode: none + # selective_ac_option: op # Teacher model configuration teacher: @@ -77,13 +76,13 @@ teacher: training: seq_len: ${trainer.training.seq_len} dtype: bfloat16 - gc_freq: 1 + gc_freq: 10 compile: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 2 - tensor_parallel_degree: 1 # Use 2 GPUs for teacher + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -95,17 +94,17 @@ teacher: # Resource allocations (3 GPUs total) services: student_generator: - procs: 1 # Student inference: 1 GPU - num_replicas: 1 + procs: 1 + num_replicas: 4 mesh_name: student_generator with_gpus: true teacher: - procs: 2 # Teacher: 2 GPUs with TP - num_replicas: 1 + procs: 1 + num_replicas: 2 mesh_name: teacher with_gpus: true trainer: - procs: 2 # Student training: shares GPU with student_generator + procs: 1 num_replicas: 1 mesh_name: trainer with_gpus: true diff --git a/apps/sft/main.py b/apps/sft/main.py index 93ba05eed..c63a59b79 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -25,7 +25,11 @@ from forge.controller import ForgeActor from forge.data.collate import collate_packed from forge.data.datasets.packed import PackedDataset, TextPacker -from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset +from forge.data.datasets.sft_dataset import ( + AlpacaToMessages, + OpenThoughtsToMessages, + sft_iterable_dataset, +) from forge.data.tokenizer import HuggingFaceModelTokenizer from forge.observability import get_or_create_metric_logger, record_metric, Reduce from forge.util.config import parse @@ -165,13 +169,32 @@ def setup_data(self): 
), ) + # Get dataset configuration from job_config + dataset_config = self.job_config["dataset"] + dataset_path = dataset_config["path"] + dataset_split = dataset_config["split"] + message_transform_type = dataset_config.get("message_transform", "alpaca") + masking_strategy = dataset_config.get("masking_strategy", "train_on_assistant") + + # Select the appropriate message transform + if message_transform_type == "openthoughts": + message_transform = OpenThoughtsToMessages( + masking_strategy=masking_strategy + ) + elif message_transform_type == "alpaca": + message_transform = AlpacaToMessages(masking_strategy=masking_strategy) + else: + raise ValueError( + f"Unknown message_transform type: {message_transform_type}" + ) + dataset = sft_iterable_dataset( model_transform=tokenizer, - message_transform=AlpacaToMessages(), - path="yahma/alpaca-cleaned", - split="train", + message_transform=message_transform, + path=dataset_path, + split=dataset_split, ) - packer = TextPacker(padding_idx=0) + packer = TextPacker(padding_idx=151643) dataset = PackedDataset( dataset=dataset, packer=packer, diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index 8684c307e..54a079f81 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -15,10 +15,6 @@ import torch import torch.nn.functional as F -from forge.controller import ForgeActor -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer - # from forge.util.ops import compute_logprobs from monarch.actor import current_rank, current_size, endpoint from torch.distributed.tensor import DTensor @@ -34,6 +30,10 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig +from forge.controller import ForgeActor +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 7f05bf102..c58600f51 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -16,18 +16,6 @@ import torch.distributed.checkpoint as dcp import torchstore as ts -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, - rdma_available, -) - -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer - from monarch.actor import endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -48,6 +36,18 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig +from forge.actors._torchstore_utils import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, + rdma_available, +) + +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/src/forge/data/datasets/__init__.py b/src/forge/data/datasets/__init__.py index 8d62e221d..64631a6d1 100644 --- a/src/forge/data/datasets/__init__.py +++ b/src/forge/data/datasets/__init__.py @@ -7,13 +7,20 @@ 
from .dataset import DatasetInfo, InfiniteTuneIterableDataset, InterleavedDataset from .hf_dataset import HfIterableDataset from .packed import PackedDataset -from .sft_dataset import sft_iterable_dataset, SFTOutputTransform +from .sft_dataset import ( + AlpacaToMessages, + OpenThoughtsToMessages, + sft_iterable_dataset, + SFTOutputTransform, +) __all__ = [ + "AlpacaToMessages", "DatasetInfo", "HfIterableDataset", "InterleavedDataset", "InfiniteTuneIterableDataset", + "OpenThoughtsToMessages", "PackedDataset", "SFTOutputTransform", "sft_iterable_dataset", diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 00278c1e5..b2f1beef2 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -105,6 +105,75 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: return {"messages": messages} +class OpenThoughtsToMessages: + """ + Message transform class for OpenThoughts-style datasets with a "conversations" column + containing a list of dictionaries with "from" and "value" fields. + + Args: + column_map (dict[str, str] | None): a mapping to change the expected "conversations" + column name to the actual column name in the dataset. Default is None, + keeping the default column name. + masking_strategy (str): masking strategy to use for model training. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + Default is "train_on_assistant". + + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked + """ + + def __init__( + self, + column_map: dict[str, str] | None = None, + masking_strategy: str = "train_on_assistant", + ): + self.masking_strategy = masking_strategy + if column_map: + if "conversations" not in column_map: + raise ValueError( + f"Expected a key of 'conversations' in column_map but found {column_map.keys()}." + ) + self._column_map = column_map + else: + self._column_map = { + "conversations": "conversations", + } + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + conversations = sample[self._column_map["conversations"]] + + if not isinstance(conversations, list): + raise ValueError( + f"Expected 'conversations' to be a list, got {type(conversations)}" + ) + + messages = [] + for message_dict in conversations: + role = message_dict.get("from", "") + content = message_dict.get("value", "") + + # Map OpenThoughts roles to standard roles + if role in ["human", "user"]: + role = "user" + elif role in ["gpt", "assistant", "model"]: + role = "assistant" + else: + # Skip unknown roles + continue + + messages.append( + TuneMessage( + role=role, + content=content, + eot=True, + ) + ) + + mask_messages(messages, self.masking_strategy) + return {"messages": messages} + + class SFTOutputTransform: """Applied to each dataset sample to build the `"labels"` tensor for causal-LM SFT training.