-
Notifications
You must be signed in to change notification settings - Fork 51
[WIP] On Policy Distillation (sync) #527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
joecummings
wants to merge
3
commits into
meta-pytorch:main
Choose a base branch
from
joecummings:on-policy-distillation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # On-Policy Distillation for Math Reasoning | ||
|
|
||
| This app implements on-policy distillation (OPD) following the approach described in the [Thinking Machines blog post](https://thinkingmachines.ai/blog/on-policy-distillation/). OPD combines the benefits of on-policy training with dense reward signals for efficient post-training. | ||
|
|
||
| ## Overview | ||
|
|
||
| On-policy distillation trains a student model by: | ||
| 1. Sampling trajectories from the student model itself | ||
| 2. Using a teacher model to grade each token with dense rewards (per-token KL divergence) | ||
| 3. Training the student to minimize reverse KL with the teacher | ||
|
|
||
| This approach is **10-30x more compute efficient** than traditional RL while achieving comparable or better performance. | ||
|
|
||
| ## Experimental Setup | ||
|
|
||
| ### Models | ||
| - **Student**: Qwen3-0.6B-Base (or Qwen3-8B for larger experiments) | ||
| - **Teacher**: Qwen3-8B (or Qwen3-32B) | ||
| - **Evaluation**: AIME'24 benchmark | ||
|
|
||
| ### Training Pipeline | ||
|
|
||
| #### Phase 1: Supervised Fine-Tuning (SFT) | ||
| First, establish a strong baseline through off-policy distillation: | ||
|
|
||
| ```bash | ||
| python -m apps.sft.main --config apps/sft/qwen3_0_6.yaml | ||
| ``` | ||
|
|
||
| - **Dataset**: OpenThoughts3-1.2M (400k prompts) | ||
| - **Expected Performance**: ~60% on AIME'24 | ||
| - **Purpose**: Teaches the model basic math reasoning patterns | ||
|
|
||
| #### Phase 2: On-Policy Distillation | ||
| Refine the model using on-policy learning with dense supervision: | ||
|
|
||
| ```bash | ||
| python -m apps.on-policy-distillation.main --config apps/on-policy-distillation/qwen_opd.yaml | ||
| ``` | ||
|
|
||
| - **Starting Point**: SFT checkpoint from Phase 1 | ||
| - **Dataset**: Math prompts (from OpenThoughts3 or DeepMath, but only prompts - not solutions) | ||
| - **Training**: ~150 steps (77k prompts with 4 samples each) | ||
| - **Expected Performance**: ~70% on AIME'24 | ||
|
|
||
| ### Key Implementation Details | ||
|
|
||
| 1. **Loss Function**: Per-token reverse KL divergence | ||
| ```python | ||
| reverse_kl = -(student_logprobs - teacher_logprobs) | ||
| ``` | ||
|
|
||
| 2. **Sampling**: Generate multiple trajectories per prompt (n=16 in config) | ||
|
|
||
| 3. **No Discount Factor**: Optimize only immediate next token (discount=0) | ||
|
|
||
| 4. **Efficient Batching**: Can use smaller batch sizes than RL due to dense rewards | ||
|
|
||
| ## Evaluation | ||
|
|
||
| Evaluate on AIME'24 benchmark after each phase: | ||
|
|
||
| ```bash | ||
| python -m apps.eval.aime --checkpoint <path_to_checkpoint> | ||
| ``` | ||
|
|
||
| ## Expected Results | ||
|
|
||
| | Method | AIME'24 Score | Training Cost | | ||
| |--------|---------------|---------------| | ||
| | SFT (400k prompts) | ~60% | Baseline | | ||
| | SFT (2M prompts, extrapolated) | ~70% | 5x baseline | | ||
| | OPD (150 steps) | ~70% | 0.1-0.3x baseline | | ||
|
|
||
| ## Key Advantages | ||
|
|
||
| - **Compute Efficiency**: 10-30x reduction vs traditional RL | ||
| - **Dense Supervision**: Learns from every token, not just final rewards | ||
| - **Data Efficiency**: Can reuse prompts multiple times effectively | ||
| - **Stability**: More stable training than sparse RL rewards | ||
|
|
||
| ## Notes for Reproduction | ||
|
|
||
| 1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD | ||
| 2. **Use prompts only**: During OPD, sample completions from student, don't use dataset solutions | ||
| 3. **Teacher quality matters**: Better teachers provide better supervision | ||
| 4. **Monitor reverse KL**: Should decrease to near-zero as training progresses | ||
|
|
||
| ## References | ||
|
|
||
| - [On-Policy Distillation Blog Post](https://thinkingmachines.ai/blog/on-policy-distillation/) | ||
| - [Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook) | ||
| - [OpenThoughts3 Dataset](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M) | ||
|
|
||
| --- | ||
|
|
||
| **Important Code Modification Needed**: Your current OPD implementation should: | ||
| 1. Load from an SFT checkpoint (not raw base model) | ||
| 2. Extract only prompts from the dataset (not use the solutions) | ||
| 3. Add proper checkpoint loading in the trainer config: | ||
|
|
||
| ```yaml | ||
| trainer: | ||
| checkpoint: | ||
| initial_load_path: ./checkpoint_student/sft_final # Load SFT checkpoint | ||
| # ... rest of config | ||
| ``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
|
|
||
| @dataclass | ||
| class DatasetConfig: | ||
| source: str | ||
| split: str = "train" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,222 @@ | ||
| import asyncio | ||
| import itertools | ||
| import time | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| import torchstore as ts | ||
| from datasets import load_dataset | ||
| from forge.actors.generator import Generator | ||
| from forge.actors.reference_model import ReferenceModel | ||
| from forge.actors.trainer import RLTrainer | ||
| from forge.controller.provisioner import init_provisioner, shutdown | ||
| from forge.data_models.completion import Completion | ||
| from forge.observability.metric_actors import get_or_create_metric_logger | ||
| from forge.util.config import parse | ||
| from forge.util.ops import compute_logprobs | ||
| from omegaconf import DictConfig | ||
| from vllm.transformers_utils.tokenizer import get_tokenizer | ||
|
|
||
|
|
||
| @dataclass | ||
| class Trajectory: | ||
| pad_id: int | ||
| request_len: int | ||
| response_len: int | ||
| # Processed data | ||
| completion: Completion | None = None | ||
| teacher_logprobs: torch.Tensor | None = None | ||
|
|
||
| @property | ||
| def request_tensor(self) -> torch.Tensor: | ||
| tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) | ||
| if tensor.shape[0] < self.request_len: # left pad | ||
| diff = self.request_len - tensor.shape[0] | ||
| tensor = F.pad(tensor, (diff, 0), value=self.pad_id) | ||
| elif tensor.shape[0] > self.request_len: # truncate | ||
| tensor = tensor[-self.request_len :] | ||
| return tensor | ||
|
|
||
| @property | ||
| def response_tensor(self) -> torch.Tensor: | ||
| tensor: torch.Tensor = self.completion.token_ids.to(torch.long) | ||
| if tensor.shape[0] < self.response_len: # right pad | ||
| diff = self.response_len - tensor.shape[0] | ||
| tensor = F.pad(tensor, (0, diff), value=self.pad_id) | ||
| elif tensor.shape[0] > self.response_len: # truncate | ||
| tensor = tensor[: self.response_len] | ||
| return tensor | ||
|
|
||
|
|
||
| def collate( | ||
| batches: list[list[Trajectory]], | ||
| ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: | ||
| inputs = [] | ||
| targets = [] | ||
| for batch in batches: | ||
| request = [t.request_tensor for t in batch] | ||
| request = torch.stack(request) | ||
|
|
||
| response = [t.response_tensor for t in batch] | ||
| response = torch.stack(response) | ||
|
|
||
| teacher_logprobs = [t.teacher_logprobs for t in batch] | ||
| teacher_logprobs = torch.stack(teacher_logprobs) | ||
|
|
||
| # student_logprobs = [t.completion.logprobs for t in batch] | ||
| # student_logprobs = torch.stack(student_logprobs) | ||
|
|
||
| pad_id = batch[0].pad_id | ||
| padding_mask = response != pad_id | ||
|
|
||
| input = {"tokens": torch.cat([request, response], dim=1)} | ||
| target = { | ||
| "response": response, | ||
| "teacher_logprobs": teacher_logprobs, | ||
| # "student_logprobs": student_logprobs, | ||
| "padding_mask": padding_mask, | ||
| } | ||
| inputs.append(input) | ||
| targets.append(target) | ||
| return inputs, targets | ||
|
|
||
|
|
||
| def importance_sampling_loss( | ||
| logits: torch.Tensor, | ||
| response: torch.Tensor, | ||
| teacher_logprobs: torch.Tensor, | ||
| # student_logprobs: torch.Tensor, | ||
| padding_mask: torch.Tensor, | ||
| **kwargs, | ||
| ) -> torch.Tensor: | ||
| student_logprobs = compute_logprobs(logits, response) | ||
| reverse_kl = -(student_logprobs - teacher_logprobs) | ||
| prob_ratio = torch.exp(teacher_logprobs - student_logprobs) | ||
| per_token_loss = prob_ratio * reverse_kl | ||
|
|
||
| # Apply mask and compute mean over valid tokens | ||
| masked_loss = per_token_loss * padding_mask | ||
| num_valid_tokens = padding_mask.sum(dim=1, keepdim=True).clamp(min=1.0) | ||
| loss_per_sequence = masked_loss.sum(dim=1, keepdim=True) / num_valid_tokens | ||
| loss = loss_per_sequence.mean() | ||
|
|
||
| return loss | ||
|
|
||
|
|
||
| async def main(cfg: DictConfig): | ||
| train_batch_size = cfg.train_batch_size | ||
| max_steps = cfg.trainer.training.steps | ||
| max_req_tokens = cfg.max_req_tokens | ||
| max_res_tokens = cfg.max_res_tokens | ||
|
|
||
| provisioner = await init_provisioner() | ||
| mlogger = await get_or_create_metric_logger(process_name="Controller") | ||
| await mlogger.init_backends.call_one( | ||
| { | ||
| "wandb": {"project": "opd-v0", "logging_mode": "global_reduce"}, | ||
| "console": {"logging_mode": "global_reduce"}, | ||
| } | ||
| ) | ||
| student_trainer, student_generator, teacher = await asyncio.gather( | ||
| RLTrainer.options(**cfg.services.trainer).as_actor( | ||
| **cfg.trainer, loss=importance_sampling_loss | ||
| ), | ||
| Generator.options(**cfg.services.student_generator).as_service( | ||
| **cfg.student_generator | ||
| ), | ||
| ReferenceModel.options(**cfg.services.teacher).as_service(**cfg.teacher), | ||
| ) | ||
|
|
||
| # Setup torchstore for weight management | ||
| trainer_num_procs = cfg.services.trainer["procs"] | ||
| trainer_host_mesh_name = cfg.services.trainer["mesh_name"] | ||
| trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) | ||
| await ts.initialize( | ||
| mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), | ||
| strategy=ts.LocalRankStrategy(), | ||
| ) | ||
|
|
||
| # Load dataset | ||
| tokenizer = get_tokenizer(cfg.student_model) | ||
| pad_id = tokenizer.pad_token_id | ||
| dataset = load_dataset(cfg.dataset.path, split=cfg.dataset.get("split", "train")) | ||
| # dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"]) | ||
| dataset_iter = iter(dataset) | ||
|
|
||
| print("All services initialized successfully!") | ||
|
|
||
| step = 0 | ||
| for epoch in range(max_steps): | ||
| # start time | ||
| start = time.perf_counter() | ||
| if step >= max_steps: | ||
| break | ||
|
|
||
| trajectories = [] | ||
| while len(trajectories) < train_batch_size: | ||
| try: | ||
| sample = next(dataset_iter) | ||
| conversation = sample["conversations"] | ||
| prompt = conversation[0]["value"] | ||
|
|
||
| completions = await student_generator.generate.fanout(prompt) | ||
|
|
||
| for completion in itertools.chain(*completions): | ||
| # Create trajectory with raw completion | ||
| trajectory = Trajectory( | ||
| pad_id=pad_id, | ||
| request_len=max_req_tokens, | ||
| response_len=max_res_tokens, | ||
| completion=completion, | ||
| ) | ||
|
|
||
| # Build padded input for teacher using trajectory properties | ||
| input_ids = torch.cat( | ||
| [ | ||
| trajectory.request_tensor.unsqueeze(0), | ||
| trajectory.response_tensor.unsqueeze(0), | ||
| ], | ||
| dim=1, | ||
| ) | ||
|
|
||
| teacher_logprobs = await teacher.forward.route( | ||
| input_ids, max_req_tokens, return_logprobs=True | ||
| ) | ||
|
|
||
| trajectory.teacher_logprobs = teacher_logprobs | ||
| trajectories.append(trajectory) | ||
| except StopIteration: | ||
| print("Dataset exhausted, resetting iterator") | ||
| dataset_iter = iter(dataset) | ||
|
|
||
| # Train on collected trajectories | ||
| trajectories = [ | ||
| trajectories[i::train_batch_size] for i in range(train_batch_size) | ||
| ] | ||
| inputs, targets = collate(trajectories) | ||
| await student_trainer.train_step.call(inputs, targets) | ||
|
|
||
| step += 1 | ||
|
|
||
| # Push weights to student generator | ||
| await student_trainer.push_weights.call(step) | ||
| await student_generator.update_weights.fanout(step) | ||
|
|
||
| end = time.perf_counter() | ||
| print(f"Step {step} took {end - start} seconds") | ||
|
|
||
| await mlogger.flush.call_one(step) | ||
|
|
||
| print(f"Training completed after {step} steps") | ||
| await shutdown() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
| @parse | ||
| def _main(cfg): | ||
| asyncio.run(main(cfg)) | ||
|
|
||
| _main() | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.