From 4260a8bbae68e2e22bc08468a446295d04435dd1 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 4 Nov 2025 17:05:51 -0800 Subject: [PATCH 1/3] Add online distillation app MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds a new online distillation app that demonstrates how to use a smaller student model to learn from a larger teacher model via KL divergence on generated completions. Key components: - Student model (Qwen3-1.7B): Generates rollouts and gets trained - Teacher model (Qwen3-32B): Frozen, provides target logprobs for distillation - Loss: Pure KL divergence between student and teacher distributions - No rewards or advantages - direct distillation objective Implementation based on apps/grpo/main.py with key differences: - Removed RewardActor and ComputeAdvantages - Replaced GRPO loss with distillation_loss (KL divergence) - Simplified Episode dataclass (no reward/advantage fields) - Renamed policy → student_generator, ref_model → teacher_model for clarity Usage: python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml Test Plan: - Verified Python syntax with py_compile - Config follows same pattern as GRPO configs --- apps/distillation/main.py | 442 ++++++++++++++++++++++ apps/distillation/qwen3_distillation.yaml | 142 +++++++ 2 files changed, 584 insertions(+) create mode 100644 apps/distillation/main.py create mode 100644 apps/distillation/qwen3_distillation.yaml diff --git a/apps/distillation/main.py b/apps/distillation/main.py new file mode 100644 index 000000000..dd5ceb28a --- /dev/null +++ b/apps/distillation/main.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
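For orientation before the implementation: the objective described in the commit message reduces to a per-token forward KL between the teacher's and student's log-probabilities of the sampled response tokens, masked over padding and averaged. A minimal standalone sketch with plain torch and made-up values follows; the names are illustrative only and not part of this app's API.

    import torch

    # Log-probs each model assigns to the sampled response tokens
    # (2 sequences, 4 response tokens each; made-up values).
    teacher_logprobs = torch.tensor([[0.5, 0.4, 0.9, 0.2],
                                     [0.7, 0.1, 0.3, 0.6]]).log()
    student_logprobs = torch.tensor([[0.4, 0.5, 0.8, 0.3],
                                     [0.6, 0.2, 0.2, 0.5]]).log()
    padding_mask = torch.tensor([[1., 1., 1., 0.],   # last token of seq 0 is padding
                                 [1., 1., 1., 1.]])

    # Forward KL term at each sampled token: p_teacher * (log p_teacher - log p_student)
    kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)

    # Mean over valid (non-padded) tokens per sequence, then over the batch
    loss = ((kl * padding_mask).sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)).mean()
    print(loss)  # scalar distillation objective on these dummy values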
+ +# Usage: python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml + +import asyncio +import time +import uuid +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn.functional as F +import torchstore as ts +from datasets import load_dataset +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) +from forge.actors.generator import Generator +from forge.actors.reference_model import ReferenceModel +from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import RLTrainer +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data_models.completion import Completion +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + +from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse +from forge.util.ops import compute_logprobs +from monarch.actor import endpoint +from omegaconf import DictConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + + +@dataclass +class Episode: + episode_id: str + pad_id: int + request_len: int + response_len: int + target: Any | None = None + # Processed data + completion: Completion | None = None + teacher_logprobs: torch.Tensor | None = None + + @property + def policy_version(self) -> int | None: + return self.completion.generator_version + + @property + def request_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.token_ids.to(torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor + + +# Represents the group (G) of episodes +Group = list[Episode] + +# Represents the Student Model to generate data and train +StudentGenerator = Generator + + +def collate( + batches: list[Group], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Collates a list of batches into a single batch of inputs and targets for distillation. + Each batch is a list of episodes, and each episode is a dict of tensors. + """ + inputs = [] + targets = [] + for batch in batches: + request = [e.request_tensor for e in batch] + request = torch.stack(request) # [b x s] + + response = [e.response_tensor for e in batch] + response = torch.stack(response) # [b x s] + + teacher_logprobs = [e.teacher_logprobs for e in batch] + teacher_logprobs = torch.stack(teacher_logprobs).squeeze() # [b x s] + + pad_id = batch[0].pad_id + mask = response != pad_id + + input = {"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "teacher_logprobs": teacher_logprobs, + "padding_mask": mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +def distillation_loss( + logits: torch.Tensor, + response: torch.Tensor, + teacher_logprobs: torch.Tensor, + padding_mask: torch.Tensor, +) -> torch.Tensor: + """ + Online distillation loss using KL divergence between student and teacher. 
+
+    Args:
+        logits: Student model logits [batch_size, seq_len, vocab_size]
+        response: Response token ids [batch_size, seq_len]
+        teacher_logprobs: Teacher log probabilities [batch_size, seq_len]
+        padding_mask: Mask for valid (non-padding) tokens [batch_size, seq_len]
+
+    Returns:
+        KL divergence loss averaged over valid tokens
+    """
+    student_logprobs: torch.Tensor = compute_logprobs(logits, response)
+
+    # Forward KL: KL(teacher || student) = E_teacher[log(teacher) - log(student)]
+    # This encourages student to cover all modes of teacher
+    kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)
+
+    # Average over valid (non-padded) tokens
+    per_token_loss = kl
+    loss = (
+        ((per_token_loss * padding_mask).sum(dim=1))
+        / (padding_mask.sum(dim=1).clamp(min=1.0))
+    ).mean()
+
+    return loss
+
+
+@dataclass
+class DatasetActor(ForgeActor):
+    """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+    path: str = "openai/gsm8k"
+    revision: str = "main"
+    data_split: str = "train"
+    streaming: bool = True
+    model: str = "Qwen/Qwen3-1.7B"
+
+    @endpoint
+    def setup(self):
+        self._tokenizer = get_tokenizer(self.model)
+        self._epoch = 0
+
+        def gsm8k_transform(sample):
+            system_prompt = """
+            Put all your scratchpad work between <think> and </think> tags.
+            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+            """
+            request: str = sample["question"]
+            as_chat = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": request},
+            ]
+            formatted_request = self._tokenizer.apply_chat_template(
+                as_chat,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            target: str = sample["answer"]
+            formatted_target = target.split("#### ")[1]
+            return {"request": formatted_request, "target": formatted_target}
+
+        self._base_dataset = load_dataset(
+            self.path, self.revision, split=self.data_split, streaming=self.streaming
+        )
+        self._base_dataset = self._base_dataset.map(gsm8k_transform)
+        self._base_dataset = self._base_dataset.shuffle()
+        self._iterator = iter(self._base_dataset)
+
+    @endpoint
+    async def sample(self) -> dict[str, str] | None:
+        try:
+            sample = next(self._iterator)
+
+            record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
+            record_metric(
+                "dataset/sample/avg_sample_len",
+                len(sample["request"]),
+                Reduce.MEAN,
+            )
+            record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
+
+            return sample
+        except StopIteration:
+            # Restart iterator for next epoch with reshuffling
+            self._epoch += 1
+            print(
+                f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
+            )
+            self._base_dataset.set_epoch(self._epoch)
+            self._iterator = iter(self._base_dataset)
+            return next(self._iterator)
+
+    @endpoint
+    async def pad_token(self):
+        return self._tokenizer.pad_token_id
+
+
+async def drop_weights(version: int):
+    print(f"Dropping weights @ version {version}")
+    start_time = time.perf_counter()
+    prefix = get_param_prefix(version)
+    matching_keys = await ts.keys(prefix)
+    # TODO: once we have something like `get_meta()` in torchstore, we can just
+    # query the type of the object instead of relying on keys.
+ dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") + + +async def main(cfg: DictConfig): + """Main online distillation training loop with rollout and training processes.""" + group_size = cfg.group_size + max_req_tokens = cfg.max_req_tokens + max_res_tokens = cfg.max_res_tokens + + # ---- Global setups ---- # + provisioner = None + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + + metric_logging_cfg = cfg.get("metric_logging", {}) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + + # ---- Setup services ---- # + + ( + dataloader, + student_generator, + trainer, + replay_buffer, + teacher_model, + ) = await asyncio.gather( + DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), + StudentGenerator.options(**cfg.services.student_generator).as_service( + **cfg.student_generator + ), + RLTrainer.options(**cfg.actors.trainer).as_actor( + **cfg.trainer, loss=distillation_loss + ), + ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor( + **cfg.replay_buffer, collate=collate + ), + ReferenceModel.options(**cfg.services.teacher_model).as_service( + **cfg.teacher_model + ), + ) + + # Set max_steps to the configured value, or -1 if not specified or Null + max_steps = cfg.trainer.training.steps or -1 + + print("All services initialized successfully!") + shutdown_event = asyncio.Event() + # Here we spawn a torchstore storage volume per trainer process. + # We initialize after service initialization because torchstore currently + # requires access to the underlying proc meshes in the local rank strategy. + # We should be able to hide this in the future. 
+ # TODO: support multiple host meshes + trainer_num_procs = cfg.actors.trainer["procs"] + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), + strategy=ts.LocalRankStrategy(), + ) + print("Torchstore successfully initialized with local rank strategy") + + # ---- Core distillation loops ---- # + async def continuous_rollouts(): + rollout_count = 0 + pad_id = await dataloader.pad_token.call_one() + while not shutdown_event.is_set(): + t = Tracer("main_perf/continuous_rollouts") + t.start() + sample = await dataloader.sample.call_one() + if sample is None: + print("Dataloader is empty, exiting continuous rollout") + return + + t.step("data_loading") + + prompt, target = sample["request"], sample["target"] + responses: list[Completion] = await student_generator.generate.route(prompt) + t.step("student_generation") + + # Construct episodes with teacher logprobs + episodes = [] + input_ids = torch.ones( + (group_size, max_req_tokens + max_res_tokens), + dtype=torch.long, + ) + for i, response in enumerate(responses): + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + target=target, + completion=response, + ) + episodes.append(episode) + + # Build input_ids for teacher logprobs + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor + + t.step("episode_construction") + + # Get teacher logprobs on student-generated completions + teacher_logprobs = await teacher_model.forward.route( + input_ids, max_req_tokens, return_logprobs=True + ) + t.step("teacher_model_calculate_logprobs") + + for i, episode in enumerate(episodes): + episode.teacher_logprobs = teacher_logprobs[i] + await replay_buffer.add.call_one(episode) + + del teacher_logprobs, input_ids + + rollout_count += 1 + record_metric( + "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM + ) + t.stop() + + async def continuous_training(): + training_step = 0 + restart_tracer = True # Flag to control when to restart tracer + + while max_steps == -1 or training_step < max_steps: + # Restart tracer when needed (initial start or after completing a training step) + # Otherwise, we cannot measure time waiting for buffer + if restart_tracer: + t = Tracer("main_perf/continuous_training") + t.start() + restart_tracer = False + + batch = await replay_buffer.sample.call_one( + curr_policy_version=training_step + ) + if batch is None: + await asyncio.sleep(0.1) + else: + t.step("waiting_for_buffer") + + inputs, targets = batch + await trainer.train_step.call(inputs, targets) + training_step += 1 + t.step("train_step") + + await trainer.push_weights.call(training_step) + t.step("push_weights") + + await student_generator.update_weights.fanout(training_step) + t.step("update_weights") + + if training_step >= 2: + await drop_weights(training_step - 1) + t.step("drop_weights") + + t.stop() + restart_tracer = True + + # Flush metrics every training step to WandB + await mlogger.flush.call_one(training_step) + + print( + f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." 
+ ) + + num_rollout_threads = cfg.get("rollout_threads", 1) + num_training_threads = cfg.get("training_threads", 1) + print( + f"Starting online distillation with {num_rollout_threads} rollout threads, {num_training_threads} training threads" + ) + rollout_tasks = [ + asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) + ] + training_task = asyncio.create_task(continuous_training()) + + try: + await training_task + except KeyboardInterrupt: + print("Training interrupted by user") + finally: + print("Shutting down... (this may take a few seconds)") + shutdown_event.set() + + try: + # Give rollouts up to 5s to finish naturally + await asyncio.wait_for( + asyncio.gather(*rollout_tasks, return_exceptions=True), + timeout=5, + ) + except asyncio.TimeoutError: + print("Timeout waiting for rollouts; forcing cancellation...") + for t in rollout_tasks: + t.cancel() + await asyncio.gather(*rollout_tasks, return_exceptions=True) + + training_task.cancel() + + await shutdown() + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() # @parse grabs the cfg from CLI diff --git a/apps/distillation/qwen3_distillation.yaml b/apps/distillation/qwen3_distillation.yaml new file mode 100644 index 000000000..c4f04e2c6 --- /dev/null +++ b/apps/distillation/qwen3_distillation.yaml @@ -0,0 +1,142 @@ +# Online Distillation: Student (Qwen3-1.7B) learns from Teacher (Qwen3-32B) +# >>> python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml + +# Global configuration +group_size: 8 +local_batch_size: 16 # per-device batch size +max_req_tokens: 1024 +max_res_tokens: 1024 +student_model: "Qwen/Qwen3-1.7B" +teacher_model: "Qwen/Qwen3-32B" +off_by_n: 1 # Off by one by default + +# Main loop configuration +rollout_threads: 1 # Recommended to set equal to student_generator.num_replicas + +# Observability configuration +metric_logging: + wandb: + project: distillation-training + group: distillation_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + console: + logging_mode: global_reduce + +# Dataset configuration +dataset: + path: "openai/gsm8k" + revision: "main" + data_split: "train" + streaming: true + model: ${student_model} # Use student model tokenizer + +# Student Generator configuration (vLLM) +student_generator: + engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs + model: ${student_model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams + n: ${group_size} + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration (Student model training) +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${student_model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${local_batch_size} + seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens + max_norm: 1.0 + steps: 1000000 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + folder: ./checkpoint # The folder to 
save checkpoints to. + initial_load_path: hf://${student_model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration +replay_buffer: + batch_size: ${local_batch_size} + max_policy_age: ${off_by_n} + dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree + +# Teacher model configuration (frozen, larger model) +teacher_model: + model: + name: qwen3 + flavor: 32B + hf_assets_path: hf://${teacher_model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 4 # 32B model needs more parallelism + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: hf://${teacher_model} + initial_load_in_hf: true + +# All resource allocations +services: + student_generator: + procs: ${student_generator.engine_args.tensor_parallel_size} + num_replicas: 1 + mesh_name: student_generator + with_gpus: true + teacher_model: + procs: ${teacher_model.parallelism.tensor_parallel_degree} + num_replicas: 1 + mesh_name: teacher_model + with_gpus: true + +actors: + dataset: + procs: 1 + with_gpus: false + mesh_name: dataset + trainer: + procs: 1 + with_gpus: true + mesh_name: trainer + replay_buffer: + procs: 1 + with_gpus: false + mesh_name: replay_buffer From e301b4dd7fa49b690d455dad7f9d7ae0a12418c9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 4 Nov 2025 17:12:27 -0800 Subject: [PATCH 2/3] Use reverse KL divergence for distillation loss Changed from forward KL to reverse KL to match the standard online distillation objective: KL(student || teacher) = E_{x~student}[log student - log teacher] This is the natural choice for online distillation because: 1. We sample from the student policy during rollouts 2. Reverse KL is mode-seeking (student focuses on teacher's high-probability modes) 3. Simpler formula: just student_logprobs - teacher_logprobs 4. More stable gradients The formula is now: kl = student_logprobs - teacher_logprobs instead of the previous forward KL: kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs) --- apps/distillation/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/distillation/main.py b/apps/distillation/main.py index dd5ceb28a..478465f7e 100644 --- a/apps/distillation/main.py +++ b/apps/distillation/main.py @@ -118,7 +118,7 @@ def distillation_loss( padding_mask: torch.Tensor, ) -> torch.Tensor: """ - Online distillation loss using KL divergence between student and teacher. + Online distillation loss using reverse KL divergence between student and teacher. 
Args: logits: Student model logits [batch_size, seq_len, vocab_size] @@ -127,13 +127,14 @@ def distillation_loss( padding_mask: Mask for valid (non-padding) tokens [batch_size, seq_len] Returns: - KL divergence loss averaged over valid tokens + Reverse KL divergence loss averaged over valid tokens """ student_logprobs: torch.Tensor = compute_logprobs(logits, response) - # Forward KL: KL(teacher || student) = E_teacher[log(teacher) - log(student)] - # This encourages student to cover all modes of teacher - kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs) + # Reverse KL: KL(student || teacher) = E_{x~student}[log student - log teacher] + # This is mode-seeking: student focuses on teacher's high-probability modes + # Since we sample from student in online distillation, this is the natural choice + kl = student_logprobs - teacher_logprobs # Average over valid (non-padded) tokens per_token_loss = kl From 6d73d522cdf6970d35a40647ebb002def9f524a3 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 4 Nov 2025 17:16:28 -0800 Subject: [PATCH 3/3] Reformulate distillation as importance sampling with negative KL as reward Changed the loss formulation to: reward = -reverse_kl = -(sampled_logprobs - teacher_logprobs) loss = -E[importance_weight * reward] where: - reverse_kl is DETACHED (no backprop through it) - importance_weight = exp(logprobs - logprobs.detach()) This treats distillation as a reward-based objective where the "reward" is how well the student matches the teacher (negative KL). The gradient flows through the importance sampling term only, not through the KL itself. This is similar to GRPO's policy gradient term but without the KL penalty: per_token_policy_loss = exp(logprobs - logprobs.detach()) * reward loss = -per_token_policy_loss Key difference from previous implementation: - Before: Direct KL minimization with backprop through both student and teacher logprobs - Now: REINFORCE-style gradient with detached KL as reward signal --- apps/distillation/main.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/apps/distillation/main.py b/apps/distillation/main.py index 478465f7e..55eda8af9 100644 --- a/apps/distillation/main.py +++ b/apps/distillation/main.py @@ -118,7 +118,17 @@ def distillation_loss( padding_mask: torch.Tensor, ) -> torch.Tensor: """ - Online distillation loss using reverse KL divergence between student and teacher. + Online distillation loss using importance sampling with negative KL as reward. + + Formulation: + reward = -reverse_kl = -(sampled_logprobs - teacher_logprobs) + loss = -E[importance_weight * reward] + + where importance_weight = exp(logprobs - logprobs.detach()) and the KL term + is detached (no backprop through it). + + This is similar to GRPO's policy gradient but without the KL penalty term, + treating the negative KL as a reward signal. 
Args: logits: Student model logits [batch_size, seq_len, vocab_size] @@ -127,17 +137,26 @@ def distillation_loss( padding_mask: Mask for valid (non-padding) tokens [batch_size, seq_len] Returns: - Reverse KL divergence loss averaged over valid tokens + Importance sampling loss with negative KL as reward """ student_logprobs: torch.Tensor = compute_logprobs(logits, response) - # Reverse KL: KL(student || teacher) = E_{x~student}[log student - log teacher] - # This is mode-seeking: student focuses on teacher's high-probability modes - # Since we sample from student in online distillation, this is the natural choice - kl = student_logprobs - teacher_logprobs + # Compute reward as negative reverse KL (detached) + # reverse_kl = sampled_logprobs - teacher_logprobs + # reward = -reverse_kl (higher reward when student matches teacher better) + reverse_kl = student_logprobs.detach() - teacher_logprobs + reward = -reverse_kl + + # Importance sampling loss (like GRPO policy gradient, but without KL penalty) + # importance_weight = exp(current_logprobs - sampled_logprobs) + per_token_policy_loss = ( + torch.exp(student_logprobs - student_logprobs.detach()) * reward + ) + + # Negative sign because we want to maximize reward + per_token_loss = -per_token_policy_loss # Average over valid (non-padded) tokens - per_token_loss = kl loss = ( ((per_token_loss * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
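A sanity check on this surrogate (plain torch, made-up values; the names are illustrative): in the forward pass the importance weight exp(logprobs - logprobs.detach()) is exactly 1, so the loss value equals the detached reverse KL, while each token's log-probability receives a gradient proportional to its own detached KL.

    import torch

    # Per-token log-probs for 3 sampled response tokens (made-up values).
    student_logprobs = torch.tensor([-1.0, -0.5, -2.0], requires_grad=True)
    teacher_logprobs = torch.tensor([-0.8, -0.7, -1.0])

    # reward = -detached reverse KL; loss = -importance_weight * reward
    reverse_kl = student_logprobs.detach() - teacher_logprobs
    reward = -reverse_kl
    per_token_loss = -torch.exp(student_logprobs - student_logprobs.detach()) * reward
    loss = per_token_loss.mean()
    loss.backward()

    print(loss.item(), reverse_kl.mean().item())   # same forward value
    print(student_logprobs.grad, reverse_kl / 3)   # per-token grad == detached KL / n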