diff --git a/apps/distillation/main.py b/apps/distillation/main.py
new file mode 100644
index 000000000..55eda8af9
--- /dev/null
+++ b/apps/distillation/main.py
@@ -0,0 +1,462 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Usage: python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml
+
+import asyncio
+import time
+import uuid
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+import torchstore as ts
+from datasets import load_dataset
+from forge.actors._torchstore_utils import (
+ get_dcp_whole_state_dict_key,
+ get_param_prefix,
+)
+from forge.actors.generator import Generator
+from forge.actors.reference_model import ReferenceModel
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import RLTrainer
+from forge.controller.actor import ForgeActor
+from forge.controller.provisioner import init_provisioner, shutdown
+from forge.data_models.completion import Completion
+from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
+
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.config import parse
+from forge.util.ops import compute_logprobs
+from monarch.actor import endpoint
+from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+@dataclass
+class Episode:
+ episode_id: str
+ pad_id: int
+ request_len: int
+ response_len: int
+ target: Any | None = None
+ # Processed data
+ completion: Completion | None = None
+ teacher_logprobs: torch.Tensor | None = None
+
+ @property
+ def policy_version(self) -> int | None:
+ return self.completion.generator_version
+
+ @property
+ def request_tensor(self) -> torch.Tensor:
+ tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long)
+ if tensor.shape[0] < self.request_len: # left pad
+ diff = self.request_len - tensor.shape[0]
+ tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
+ return tensor
+
+ @property
+ def response_tensor(self) -> torch.Tensor:
+ tensor: torch.Tensor = self.completion.token_ids.to(torch.long)
+ if tensor.shape[0] < self.response_len: # right pad
+ diff = self.response_len - tensor.shape[0]
+ tensor = F.pad(tensor, (0, diff), value=self.pad_id)
+ return tensor
+
+
+# Represents the group (G) of episodes
+Group = list[Episode]
+
+# Represents the student model: it generates rollouts and is the model being trained
+StudentGenerator = Generator
+
+
+def collate(
+ batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+ """
+    Collates a list of episode groups into per-group inputs and targets for distillation.
+    Each group is a list of Episode objects whose padded request/response tensors and
+    teacher log-probabilities are stacked into batched tensors.
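+    Illustrative shapes (assuming group size G > 1, so squeeze() keeps the batch dim):
+    each input is {"tokens": [G, request_len + response_len]}; each target holds
+    "response", "teacher_logprobs", and "padding_mask", all shaped [G, response_len].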
+ """
+ inputs = []
+ targets = []
+ for batch in batches:
+ request = [e.request_tensor for e in batch]
+ request = torch.stack(request) # [b x s]
+
+ response = [e.response_tensor for e in batch]
+ response = torch.stack(response) # [b x s]
+
+ teacher_logprobs = [e.teacher_logprobs for e in batch]
+ teacher_logprobs = torch.stack(teacher_logprobs).squeeze() # [b x s]
+
+ pad_id = batch[0].pad_id
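+        # True for real response tokens, False for right padding; downstream this
+        # masks padded positions out of the loss.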
+ mask = response != pad_id
+
+ input = {"tokens": torch.cat([request, response], dim=1)}
+ target = {
+ "response": response,
+ "teacher_logprobs": teacher_logprobs,
+ "padding_mask": mask,
+ }
+ inputs.append(input)
+ targets.append(target)
+ return inputs, targets
+
+
+def distillation_loss(
+ logits: torch.Tensor,
+ response: torch.Tensor,
+ teacher_logprobs: torch.Tensor,
+ padding_mask: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Online distillation loss using importance sampling with negative KL as reward.
+
+ Formulation:
+ reward = -reverse_kl = -(sampled_logprobs - teacher_logprobs)
+ loss = -E[importance_weight * reward]
+
+ where importance_weight = exp(logprobs - logprobs.detach()) and the KL term
+ is detached (no backprop through it).
+
+ This is similar to GRPO's policy gradient but without the KL penalty term,
+ treating the negative KL as a reward signal.
+
+ Args:
+ logits: Student model logits [batch_size, seq_len, vocab_size]
+ response: Response token ids [batch_size, seq_len]
+ teacher_logprobs: Teacher log probabilities [batch_size, seq_len]
+ padding_mask: Mask for valid (non-padding) tokens [batch_size, seq_len]
+
+ Returns:
+ Importance sampling loss with negative KL as reward
+ """
+ student_logprobs: torch.Tensor = compute_logprobs(logits, response)
+
+ # Compute reward as negative reverse KL (detached)
+ # reverse_kl = sampled_logprobs - teacher_logprobs
+ # reward = -reverse_kl (higher reward when student matches teacher better)
+ reverse_kl = student_logprobs.detach() - teacher_logprobs
+ reward = -reverse_kl
+
+ # Importance sampling loss (like GRPO policy gradient, but without KL penalty)
+ # importance_weight = exp(current_logprobs - sampled_logprobs)
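+    # Note: exp(student_logprobs - student_logprobs.detach()) is identically 1 in the
+    # forward pass; it only exists so that the backward pass of this term yields
+    # reward * d(student_logprobs), i.e. a REINFORCE-style policy-gradient estimator.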
+ per_token_policy_loss = (
+ torch.exp(student_logprobs - student_logprobs.detach()) * reward
+ )
+
+ # Negative sign because we want to maximize reward
+ per_token_loss = -per_token_policy_loss
+
+ # Average over valid (non-padded) tokens
+ loss = (
+ ((per_token_loss * padding_mask).sum(dim=1))
+ / (padding_mask.sum(dim=1).clamp(min=1.0))
+ ).mean()
+
+ return loss
+
+
+@dataclass
+class DatasetActor(ForgeActor):
+ """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+ path: str = "openai/gsm8k"
+ revision: str = "main"
+ data_split: str = "train"
+ streaming: bool = True
+ model: str = "Qwen/Qwen3-1.7B"
+
+ @endpoint
+ def setup(self):
+ self._tokenizer = get_tokenizer(self.model)
+ self._epoch = 0
+
+ def gsm8k_transform(sample):
+ system_prompt = """
+            Put all your scratchpad work between <think> and </think> tags.
+            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+ """
+ request: str = sample["question"]
+ as_chat = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": request},
+ ]
+ formatted_request = self._tokenizer.apply_chat_template(
+ as_chat,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ target: str = sample["answer"]
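+            # GSM8K answers end with "#### <final answer>"; keep only the final value.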
+ formatted_target = target.split("#### ")[1]
+ return {"request": formatted_request, "target": formatted_target}
+
+ self._base_dataset = load_dataset(
+ self.path, self.revision, split=self.data_split, streaming=self.streaming
+ )
+ self._base_dataset = self._base_dataset.map(gsm8k_transform)
+ self._base_dataset = self._base_dataset.shuffle()
+ self._iterator = iter(self._base_dataset)
+
+ @endpoint
+ async def sample(self) -> dict[str, str] | None:
+ try:
+ sample = next(self._iterator)
+
+ record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
+ record_metric(
+ "dataset/sample/avg_sample_len",
+ len(sample["request"]),
+ Reduce.MEAN,
+ )
+ record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
+
+ return sample
+ except StopIteration:
+ # Restart iterator for next epoch with reshuffling
+ self._epoch += 1
+ print(
+ f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
+ )
+ self._base_dataset.set_epoch(self._epoch)
+ self._iterator = iter(self._base_dataset)
+ return next(self._iterator)
+
+ @endpoint
+ async def pad_token(self):
+ return self._tokenizer.pad_token_id
+
+
+async def drop_weights(version: int):
+ print(f"Dropping weights @ version {version}")
+ start_time = time.perf_counter()
+ prefix = get_param_prefix(version)
+ matching_keys = await ts.keys(prefix)
+ # TODO: once we have something like `get_meta()` in torchstore, we can just
+ # query the type of the object instead of relying on keys.
+ dcp_key = get_dcp_whole_state_dict_key(version)
+ if dcp_key in matching_keys:
+ dcp_handle = await ts.get(dcp_key)
+ dcp_handle.drop()
+ for key in matching_keys:
+ await ts.delete(key)
+ elapsed = time.perf_counter() - start_time
+ print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
+
+
+async def main(cfg: DictConfig):
+ """Main online distillation training loop with rollout and training processes."""
+ group_size = cfg.group_size
+ max_req_tokens = cfg.max_req_tokens
+ max_res_tokens = cfg.max_res_tokens
+
+ # ---- Global setups ---- #
+ provisioner = None
+ if cfg.get("provisioner", None) is not None:
+ provisioner = await init_provisioner(
+ ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+ )
+ else:
+ provisioner = await init_provisioner()
+
+ metric_logging_cfg = cfg.get("metric_logging", {})
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
+ await mlogger.init_backends.call_one(metric_logging_cfg)
+
+ # ---- Setup services ---- #
+
+ (
+ dataloader,
+ student_generator,
+ trainer,
+ replay_buffer,
+ teacher_model,
+ ) = await asyncio.gather(
+ DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
+ StudentGenerator.options(**cfg.services.student_generator).as_service(
+ **cfg.student_generator
+ ),
+ RLTrainer.options(**cfg.actors.trainer).as_actor(
+ **cfg.trainer, loss=distillation_loss
+ ),
+ ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
+ **cfg.replay_buffer, collate=collate
+ ),
+ ReferenceModel.options(**cfg.services.teacher_model).as_service(
+ **cfg.teacher_model
+ ),
+ )
+
+    # Set max_steps to the configured value, or -1 (run indefinitely) if not specified or null
+ max_steps = cfg.trainer.training.steps or -1
+
+ print("All services initialized successfully!")
+ shutdown_event = asyncio.Event()
+ # Here we spawn a torchstore storage volume per trainer process.
+ # We initialize after service initialization because torchstore currently
+ # requires access to the underlying proc meshes in the local rank strategy.
+ # We should be able to hide this in the future.
+ # TODO: support multiple host meshes
+ trainer_num_procs = cfg.actors.trainer["procs"]
+ trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
+ trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+ await ts.initialize(
+ mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
+ strategy=ts.LocalRankStrategy(),
+ )
+ print("Torchstore successfully initialized with local rank strategy")
+
+ # ---- Core distillation loops ---- #
+ async def continuous_rollouts():
+ rollout_count = 0
+ pad_id = await dataloader.pad_token.call_one()
+ while not shutdown_event.is_set():
+ t = Tracer("main_perf/continuous_rollouts")
+ t.start()
+ sample = await dataloader.sample.call_one()
+ if sample is None:
+ print("Dataloader is empty, exiting continuous rollout")
+ return
+
+ t.step("data_loading")
+
+ prompt, target = sample["request"], sample["target"]
+ responses: list[Completion] = await student_generator.generate.route(prompt)
+ t.step("student_generation")
+
+ # Construct episodes with teacher logprobs
+ episodes = []
+ input_ids = torch.ones(
+ (group_size, max_req_tokens + max_res_tokens),
+ dtype=torch.long,
+ )
+ for i, response in enumerate(responses):
+ episode = Episode(
+ episode_id=str(uuid.uuid4()),
+ pad_id=pad_id,
+ request_len=max_req_tokens,
+ response_len=max_res_tokens,
+ target=target,
+ completion=response,
+ )
+ episodes.append(episode)
+
+ # Build input_ids for teacher logprobs
+ input_ids[i, :max_req_tokens] = episode.request_tensor
+ input_ids[i, max_req_tokens:] = episode.response_tensor
+
+ t.step("episode_construction")
+
+ # Get teacher logprobs on student-generated completions
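+            # The result is expected to be [group_size, max_res_tokens]: per-token
+            # log-probs of the sampled response tokens under the teacher.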
+ teacher_logprobs = await teacher_model.forward.route(
+ input_ids, max_req_tokens, return_logprobs=True
+ )
+ t.step("teacher_model_calculate_logprobs")
+
+ for i, episode in enumerate(episodes):
+ episode.teacher_logprobs = teacher_logprobs[i]
+ await replay_buffer.add.call_one(episode)
+
+ del teacher_logprobs, input_ids
+
+ rollout_count += 1
+ record_metric(
+ "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
+ )
+ t.stop()
+
+ async def continuous_training():
+ training_step = 0
+ restart_tracer = True # Flag to control when to restart tracer
+
+ while max_steps == -1 or training_step < max_steps:
+ # Restart tracer when needed (initial start or after completing a training step)
+ # Otherwise, we cannot measure time waiting for buffer
+ if restart_tracer:
+ t = Tracer("main_perf/continuous_training")
+ t.start()
+ restart_tracer = False
+
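+            # sample() returns None until the buffer can serve a full batch; episodes
+            # more than `max_policy_age` versions behind curr_policy_version are
+            # assumed to be evicted rather than returned.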
+ batch = await replay_buffer.sample.call_one(
+ curr_policy_version=training_step
+ )
+ if batch is None:
+ await asyncio.sleep(0.1)
+ else:
+ t.step("waiting_for_buffer")
+
+ inputs, targets = batch
+ await trainer.train_step.call(inputs, targets)
+ training_step += 1
+ t.step("train_step")
+
+ await trainer.push_weights.call(training_step)
+ t.step("push_weights")
+
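+                # fanout() broadcasts the new weight version to every replica of
+                # the student generator service.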
+                await student_generator.update_weights.fanout(training_step)
+                t.step("update_weights")
+
+                if training_step >= 2:
+                    await drop_weights(training_step - 1)
+                    t.step("drop_weights")
+
+                t.stop()
+                restart_tracer = True
+
+                # Flush metrics to the configured logging backends after every training step
+ await mlogger.flush.call_one(training_step)
+
+ print(
+ f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
+ )
+
+ num_rollout_threads = cfg.get("rollout_threads", 1)
+ num_training_threads = cfg.get("training_threads", 1)
+ print(
+ f"Starting online distillation with {num_rollout_threads} rollout threads, {num_training_threads} training threads"
+ )
+ rollout_tasks = [
+ asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)
+ ]
+ training_task = asyncio.create_task(continuous_training())
+
+ try:
+ await training_task
+ except KeyboardInterrupt:
+ print("Training interrupted by user")
+ finally:
+ print("Shutting down... (this may take a few seconds)")
+ shutdown_event.set()
+
+ try:
+ # Give rollouts up to 5s to finish naturally
+ await asyncio.wait_for(
+ asyncio.gather(*rollout_tasks, return_exceptions=True),
+ timeout=5,
+ )
+ except asyncio.TimeoutError:
+ print("Timeout waiting for rollouts; forcing cancellation...")
+ for t in rollout_tasks:
+ t.cancel()
+ await asyncio.gather(*rollout_tasks, return_exceptions=True)
+
+ training_task.cancel()
+
+ await shutdown()
+
+
+if __name__ == "__main__":
+
+ @parse
+ def _main(cfg):
+ asyncio.run(main(cfg))
+
+ _main() # @parse grabs the cfg from CLI
diff --git a/apps/distillation/qwen3_distillation.yaml b/apps/distillation/qwen3_distillation.yaml
new file mode 100644
index 000000000..c4f04e2c6
--- /dev/null
+++ b/apps/distillation/qwen3_distillation.yaml
@@ -0,0 +1,142 @@
+# Online Distillation: Student (Qwen3-1.7B) learns from Teacher (Qwen3-32B)
+# >>> python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml
+
+# Global configuration
+group_size: 8
+local_batch_size: 16 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 1024
+student_model: "Qwen/Qwen3-1.7B"
+teacher_model: "Qwen/Qwen3-32B"
+off_by_n: 1 # Allow sampled episodes to be at most one policy version stale
+
+# Main loop configuration
+rollout_threads: 1 # Recommended to set equal to student_generator.num_replicas
+
+# Observability configuration
+metric_logging:
+ wandb:
+ project: distillation-training
+ group: distillation_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+ console:
+ logging_mode: global_reduce
+
+# Dataset configuration
+dataset:
+ path: "openai/gsm8k"
+ revision: "main"
+ data_split: "train"
+ streaming: true
+ model: ${student_model} # Use student model tokenizer
+
+# Student Generator configuration (vLLM)
+student_generator:
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: ${student_model}
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+ enforce_eager: false
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
+ n: ${group_size}
+ max_tokens: ${max_res_tokens}
+ temperature: 1.0
+ top_p: 1.0
+
+# Trainer configuration (Student model training)
+trainer:
+ model:
+ name: qwen3
+ flavor: 1.7B
+ hf_assets_path: hf://${student_model}
+ optimizer:
+ name: AdamW
+ lr: 1e-5
+ eps: 1e-8
+ lr_scheduler:
+ warmup_steps: 1
+ training:
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
+ max_norm: 1.0
+ steps: 1000000
+ dtype: bfloat16
+ gc_freq: 1
+ compile:
+ enable: false
+ parallelism:
+ data_parallel_replicate_degree: 1
+ data_parallel_shard_degree: 1
+ tensor_parallel_degree: 1
+ pipeline_parallel_degree: 1
+ context_parallel_degree: 1
+ expert_parallel_degree: 1
+ disable_loss_parallel: true
+ checkpoint:
+ enable: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${student_model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
+ last_save_in_hf: true
+ interval: 500
+ async_mode: "disabled"
+ activation_checkpoint:
+ mode: selective
+ selective_ac_option: op
+
+# Replay buffer configuration
+replay_buffer:
+ batch_size: ${local_batch_size}
+ max_policy_age: ${off_by_n}
+ dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
+
+# Teacher model configuration (frozen, larger model)
+teacher_model:
+ model:
+ name: qwen3
+ flavor: 32B
+    hf_assets_path: hf://${teacher_model_name}
+ training:
+ seq_len: ${trainer.training.seq_len}
+ dtype: bfloat16
+ gc_freq: 1
+ compile:
+ enable: false
+ parallelism:
+ data_parallel_replicate_degree: 1
+ data_parallel_shard_degree: 1
+ tensor_parallel_degree: 4 # 32B model needs more parallelism
+ pipeline_parallel_degree: 1
+ context_parallel_degree: 1
+ expert_parallel_degree: 1
+ checkpoint:
+ enable: true
+    initial_load_path: hf://${teacher_model_name}
+ initial_load_in_hf: true
+
+# All resource allocations
+services:
+ student_generator:
+ procs: ${student_generator.engine_args.tensor_parallel_size}
+ num_replicas: 1
+ mesh_name: student_generator
+ with_gpus: true
+ teacher_model:
+ procs: ${teacher_model.parallelism.tensor_parallel_degree}
+ num_replicas: 1
+ mesh_name: teacher_model
+ with_gpus: true
+
+actors:
+ dataset:
+ procs: 1
+ with_gpus: false
+ mesh_name: dataset
+ trainer:
+ procs: 1
+ with_gpus: true
+ mesh_name: trainer
+ replay_buffer:
+ procs: 1
+ with_gpus: false
+ mesh_name: replay_buffer
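+
+# With the settings above, the job uses 6 GPUs in total: 1 for the student
+# generator, 1 for the trainer, and 4 for the teacher (tensor_parallel_degree: 4).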