From 9fa4c46d367bbe63ebb1b641813a2f0e56fc8021 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 12 Nov 2025 21:15:33 +0000 Subject: [PATCH 01/28] activation distillation: first draft --- fast_llm/layers/block/config.py | 5 +++ fast_llm/layers/decoder/block.py | 33 +++++++++++++++---- fast_llm/layers/language_model/config.py | 8 +++++ fast_llm/layers/language_model/head.py | 33 +++++++++++++++++-- .../layers/language_model/language_model.py | 3 ++ fast_llm/models/gpt/model.py | 23 +++++++++++-- 6 files changed, 95 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index f3e93ede..2fb2fedf 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -37,6 +37,11 @@ class BlockKwargs: sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? grad_output = "grad_output" + root = "_root_kwargs" + activation_distillation_storage = "activation_distillation_storage" + activation_distillation_targets = "activation_distillation_targets" + activation_distillation_total = "activation_distillation_total" + activation_distillation_count = "activation_distillation_count" @config_class(registry=True) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 8b19db66..5081457b 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -3,6 +3,7 @@ import typing import torch +import torch.nn.functional as F from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig @@ -14,6 +15,7 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -136,13 +138,32 @@ def forward( if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "mixer output", - kwargs[BlockKwargs.hidden_dims], - kwargs, + mixer_output = hidden_states if bias is None else hidden_states + bias + root_kwargs = kwargs.get(BlockKwargs.root, kwargs) + # Teacher populates mixer activations for distillation. + activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) + if activation_storage is not None: + activation_storage[self.module_name] = mixer_output.detach() + # Student gets teacher activations and computes the activation-level loss. + activation_targets = root_kwargs.get(BlockKwargs.activation_distillation_targets) + if ( + activation_targets is not None + and self.training + and (teacher_output := activation_targets.pop(self.module_name, None)) is not None + ): + # Compare student mixer output with the teacher’s stored activation and accumulate the loss. 
+ teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) + Assert.eq(teacher_tensor.shape, mixer_output.shape) + activation_loss = F.mse_loss(mixer_output, teacher_tensor) + activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) + root_kwargs[BlockKwargs.activation_distillation_total] = ( + activation_loss if activation_total is None else activation_total + activation_loss + ) + root_kwargs[BlockKwargs.activation_distillation_count] = ( + root_kwargs.get(BlockKwargs.activation_distillation_count, 0) + 1 ) + if self._debug.enabled: + self._debug(mixer_output, "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug.enabled: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 25fa2d91..6ede28b9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -162,6 +162,12 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) + activation_distillation_factor: float = Field( + default=0.0, + desc="Factor to scale the activation-level distillation loss by when using distillation.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -233,6 +239,8 @@ def _validate(self) -> None: self.language_model_loss_factor = 0.0 super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + if self.activation_distillation_factor > 0.0 and self.distillation_model is None: + raise ValueError("Activation distillation requires a distillation_model to be configured.") @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4b0e3d10..a6ca0545 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -18,7 +18,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -419,6 +419,19 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None + activation_loss = None + root_kwargs = kwargs.get(BlockKwargs.root, kwargs) + activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) + activation_count = root_kwargs.get(BlockKwargs.activation_distillation_count, 0) + if activation_total is not None and activation_count and self._config.activation_distillation_factor > 0.0: + activation_loss = (activation_total / activation_count) * self._config.activation_distillation_factor + if losses is not None and self._activation_distillation_loss_name in losses: + losses[self._activation_distillation_loss_name].append(activation_loss.detach()) + # Activation targets are no longer needed past this point. 
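The storage/target handshake this patch introduces can be pictured with a small standalone sketch: the teacher pass fills a dict with detached mixer outputs keyed by module name, the student pass pops the matching entries and accumulates an MSE penalty, and the head scales the running average by the configured factor. The driver below is purely illustrative (toy module names, toy shapes), not the Fast-LLM API.

import torch

def teacher_pass(blocks: dict[str, torch.nn.Module], x: torch.Tensor) -> dict[str, torch.Tensor]:
    # Teacher side: populate the activation storage with detached mixer outputs.
    storage: dict[str, torch.Tensor] = {}
    for name, block in blocks.items():
        x = block(x)
        storage[name] = x.detach()
    return storage

def student_pass(
    blocks: dict[str, torch.nn.Module], x: torch.Tensor, targets: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Student side: pop the teacher activation for each block and accumulate the loss.
    total, count = None, 0
    for name, block in blocks.items():
        x = block(x)
        teacher = targets.pop(name, None)
        if teacher is not None:
            loss = torch.nn.functional.mse_loss(x, teacher.to(dtype=x.dtype))
            total = loss if total is None else total + loss
            count += 1
    return x, None if count == 0 else total / count

# Toy usage: separately initialized "teacher" and "student" stacks with matching module names.
teacher_blocks = {f"decoder.{i}": torch.nn.Linear(8, 8) for i in range(2)}
student_blocks = {f"decoder.{i}": torch.nn.Linear(8, 8) for i in range(2)}
x = torch.randn(4, 16, 8)  # (batch, sequence, hidden)
targets = teacher_pass(teacher_blocks, x)
_, activation_loss = student_pass(student_blocks, x, targets)
scaled = 0.1 * activation_loss  # the head applies activation_distillation_factor to the average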
+ root_kwargs.pop(BlockKwargs.activation_distillation_targets, None) + root_kwargs.pop(BlockKwargs.activation_distillation_total, None) + root_kwargs.pop(BlockKwargs.activation_distillation_count, None) + # TODO: de-allocate earlier. del logits @@ -426,7 +439,7 @@ def _logits_cross_entropy_forward_backward( grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? - loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + loss = _add_tensors(dpo_loss, lm_loss, distillation_loss, activation_loss) if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -472,6 +485,13 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _activation_distillation_loss_name(self) -> str: + name = "activation_distillation_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] if self._config.logit_z_loss: @@ -500,6 +520,15 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) ) + if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: + loss_defs.append( + LossDef( + name=self._activation_distillation_loss_name, + formatted_name=_format_name(self._activation_distillation_loss_name), + count=count, + ) + ) + return loss_defs diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 2e46bb57..4bfe34e4 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -8,6 +8,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -59,6 +60,8 @@ def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + # Seed a shared root pointer so nested layers (including namespaced ones) can exchange activation distillation state. + kwargs.setdefault(BlockKwargs.root, kwargs) # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ec..8df83c2a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -176,6 +176,8 @@ def preprocess_batch( ) reference_logits = [{} for _ in preprocessed_meta] + distillation_model = getattr(self._config.head, "distillation_model", None) + activation_distillation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta @@ -188,8 +190,19 @@ def preprocess_batch( # TODO: Do things work with >1? Assert.eq(len(reference_batch), len(preprocessed_meta), 1) for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): + if ( + phase != PhaseType.inference + and name == distillation_model + and activation_distillation_factor > 0.0 + ): + reference_kwargs[BlockKwargs.activation_distillation_storage] = {} reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + if BlockKwargs.activation_distillation_storage in reference_kwargs: + reference_logits[i][f"{name}_activations"] = reference_kwargs[ + BlockKwargs.activation_distillation_storage + ] + del reference_kwargs[BlockKwargs.activation_distillation_storage] token_ids = batch.token_ids if sequence_first: @@ -255,7 +268,13 @@ def preprocess_batch( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + reference_payload = reference_logits[i] + kwargs.update(reference_payload) + + if distillation_model is not None and activation_distillation_factor > 0.0: + teacher_key = f"{distillation_model}_activations" + if teacher_key in reference_payload: + kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) if batch.chosen_spans is not None: chosen_valid_spans = [] From 11708ff5713bda914cb32b6cb458a02b895cbc41 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 12 Nov 2025 23:38:48 +0000 Subject: [PATCH 02/28] fix kwargs --- fast_llm/engine/base_model/base_model.py | 1 + .../layers/language_model/language_model.py | 21 ++++++++++-- fast_llm/models/gpt/model.py | 34 ++++++++++--------- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 5df59d4c..106bea21 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -145,6 +145,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + setup_activation_storage: bool = False, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move 
batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 4bfe34e4..579eb531 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -60,8 +60,25 @@ def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - # Seed a shared root pointer so nested layers (including namespaced ones) can exchange activation distillation state. - kwargs.setdefault(BlockKwargs.root, kwargs) + # TODO: remove root_kwargs + activation_factor = getattr(self.head._config, "activation_distillation_factor", 0.0) + if ( + activation_factor > 0.0 + or BlockKwargs.activation_distillation_targets in kwargs + or BlockKwargs.activation_distillation_storage in kwargs + ): + root_state = kwargs.get(BlockKwargs.root) + if root_state is None or root_state is kwargs: + root_state = {} + kwargs[BlockKwargs.root] = root_state + if BlockKwargs.activation_distillation_targets in kwargs: + root_state[BlockKwargs.activation_distillation_targets] = kwargs[ + BlockKwargs.activation_distillation_targets + ] + if BlockKwargs.activation_distillation_storage in kwargs: + root_state[BlockKwargs.activation_distillation_storage] = kwargs[ + BlockKwargs.activation_distillation_storage + ] # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8df83c2a..25affc49 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -157,6 +157,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + setup_activation_storage: bool = False, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup @@ -175,27 +176,26 @@ def preprocess_batch( non_blocking=True, ) - reference_logits = [{} for _ in preprocessed_meta] distillation_model = getattr(self._config.head, "distillation_model", None) - activation_distillation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) + activation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) + reference_logits: list[dict[str, typing.Any]] | None = None + reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( - batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration + batch, + reference_preprocessed_meta, + phase=PhaseType.inference, + iteration=iteration, + setup_activation_storage=activation_factor > 0.0, ) # TODO: Do things work with >1? 
Assert.eq(len(reference_batch), len(preprocessed_meta), 1) for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): - if ( - phase != PhaseType.inference - and name == distillation_model - and activation_distillation_factor > 0.0 - ): - reference_kwargs[BlockKwargs.activation_distillation_storage] = {} reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] if BlockKwargs.activation_distillation_storage in reference_kwargs: @@ -268,13 +268,13 @@ def preprocess_batch( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels - reference_payload = reference_logits[i] - kwargs.update(reference_payload) - - if distillation_model is not None and activation_distillation_factor > 0.0: - teacher_key = f"{distillation_model}_activations" - if teacher_key in reference_payload: - kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) + if reference_logits is not None: + reference_payload = reference_logits[i] + kwargs.update(reference_payload) + if distillation_model is not None and activation_factor > 0.0: + teacher_key = f"{distillation_model}_activations" + if teacher_key in reference_payload: + kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) if batch.chosen_spans is not None: chosen_valid_spans = [] @@ -307,6 +307,8 @@ def preprocess_batch( rejected_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans + if setup_activation_storage: + kwargs.setdefault(BlockKwargs.activation_distillation_storage, {}) self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) From 943731090bb670406123c97a980ea4f27225e3ad Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:33:22 +0000 Subject: [PATCH 03/28] remove count, add auxiliaryLoss hook --- fast_llm/layers/block/config.py | 1 - fast_llm/layers/decoder/block.py | 15 ++++++++------- fast_llm/layers/language_model/head.py | 6 ++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 2fb2fedf..45dbe495 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -41,7 +41,6 @@ class BlockKwargs: activation_distillation_storage = "activation_distillation_storage" activation_distillation_targets = "activation_distillation_targets" activation_distillation_total = "activation_distillation_total" - activation_distillation_count = "activation_distillation_count" @config_class(registry=True) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 5081457b..d5d87cdb 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -3,7 +3,6 @@ import typing import torch -import torch.nn.functional as F from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig @@ -12,6 +11,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta @@ -139,6 +139,7 @@ def forward( 
self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) mixer_output = hidden_states if bias is None else hidden_states + bias + # TODO: wrap in method: activation_distillation root_kwargs = kwargs.get(BlockKwargs.root, kwargs) # Teacher populates mixer activations for distillation. activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) @@ -154,14 +155,14 @@ def forward( # Compare student mixer output with the teacher’s stored activation and accumulate the loss. teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) Assert.eq(teacher_tensor.shape, mixer_output.shape) - activation_loss = F.mse_loss(mixer_output, teacher_tensor) + # TODO: handle sequence-first? + activation_loss = torch.mean(torch.norm(mixer_output - teacher_tensor, p=2, dim=(1, 2))) + mixer_output = AuxiliaryLoss.apply(mixer_output, activation_loss, 1.0) activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - root_kwargs[BlockKwargs.activation_distillation_total] = ( - activation_loss if activation_total is None else activation_total + activation_loss - ) - root_kwargs[BlockKwargs.activation_distillation_count] = ( - root_kwargs.get(BlockKwargs.activation_distillation_count, 0) + 1 + activation_total = ( + activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) + root_kwargs[BlockKwargs.activation_distillation_total] = activation_total if self._debug.enabled: self._debug(mixer_output, "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs) with set_generator(generator): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index a6ca0545..290f5014 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -422,15 +422,13 @@ def _logits_cross_entropy_forward_backward( activation_loss = None root_kwargs = kwargs.get(BlockKwargs.root, kwargs) activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - activation_count = root_kwargs.get(BlockKwargs.activation_distillation_count, 0) - if activation_total is not None and activation_count and self._config.activation_distillation_factor > 0.0: - activation_loss = (activation_total / activation_count) * self._config.activation_distillation_factor + if activation_total is not None and self._config.activation_distillation_factor > 0.0: + activation_loss = activation_total * self._config.activation_distillation_factor if losses is not None and self._activation_distillation_loss_name in losses: losses[self._activation_distillation_loss_name].append(activation_loss.detach()) # Activation targets are no longer needed past this point. root_kwargs.pop(BlockKwargs.activation_distillation_targets, None) root_kwargs.pop(BlockKwargs.activation_distillation_total, None) - root_kwargs.pop(BlockKwargs.activation_distillation_count, None) # TODO: de-allocate earlier. 
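The AuxiliaryLoss hook used here is, in spirit, an identity autograd function: it leaves the forward value untouched but hands the auxiliary scalar a gradient during backward, so the activation-distillation term contributes to the parameter gradients without being folded into the reported language-model loss. Below is a generic sketch of that pattern under standard PyTorch semantics; Fast-LLM's own AuxiliaryLoss may differ in detail.

import torch

class AuxLossHook(torch.autograd.Function):
    """Identity on `activation`; routes a gradient of `grad_scale` into `aux_loss`."""

    @staticmethod
    def forward(ctx, activation: torch.Tensor, aux_loss: torch.Tensor, grad_scale: float) -> torch.Tensor:
        ctx.grad_scale = grad_scale
        return activation

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Pass the activation gradient through unchanged; the auxiliary scalar receives
        # `grad_scale`, which then back-propagates through its own computation graph.
        aux_grad = torch.full((), ctx.grad_scale, dtype=grad_output.dtype, device=grad_output.device)
        return grad_output, aux_grad, None

# Toy usage: the MSE term trains `weight` even though only `out.sum()` is back-propagated.
weight = torch.nn.Parameter(torch.randn(8, 8))
hidden = torch.randn(4, 8) @ weight
aux = torch.nn.functional.mse_loss(hidden, torch.zeros_like(hidden))
out = AuxLossHook.apply(hidden, aux, 1.0)
out.sum().backward()
assert weight.grad is not None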
del logits From d3ac9646e75ea5754c9b1d8328275b28e966c203 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:40:33 +0000 Subject: [PATCH 04/28] fix auxiliary loss --- fast_llm/layers/decoder/block.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index d5d87cdb..92f63872 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -157,7 +157,10 @@ def forward( Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: handle sequence-first? activation_loss = torch.mean(torch.norm(mixer_output - teacher_tensor, p=2, dim=(1, 2))) - mixer_output = AuxiliaryLoss.apply(mixer_output, activation_loss, 1.0) + # Backward hooks + hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) + bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None + # Logging activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) activation_total = ( activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() From 56fc8db9d3caac25e497ea8ef3e4b03ebac598b7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:45:26 +0000 Subject: [PATCH 05/28] wrap in method --- fast_llm/layers/decoder/block.py | 58 +++++++++++++++++++------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 92f63872..6fe00006 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -138,6 +138,40 @@ def forward( if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) + + self.activation_distillation_loss(hidden_states, bias, kwargs) + + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "mixer output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + input_ = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states = self.norm_2(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) + if self._return_input: + hidden_states = torch.stack((fw_input, hidden_states), dim=0) + return hidden_states + + def activation_distillation_loss(self, hidden_states, bias, kwargs): mixer_output = hidden_states if bias is None else hidden_states + bias # TODO: wrap in method: activation_distillation root_kwargs = kwargs.get(BlockKwargs.root, kwargs) @@ -166,30 +200,6 @@ def forward( activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) root_kwargs[BlockKwargs.activation_distillation_total] = activation_total - if self._debug.enabled: - self._debug(mixer_output, "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs) - with set_generator(generator): - 
input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states = self.norm_2(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "MLP output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) - with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) - if self._return_input: - hidden_states = torch.stack((fw_input, hidden_states), dim=0) - return hidden_states def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (normalization, bias_dropout_add) From 5d75f01469bce67da86f42b87695ee2365da6ce3 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 13 Nov 2025 19:47:33 +0000 Subject: [PATCH 06/28] fixes --- fast_llm/layers/decoder/block.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 6fe00006..7b058f5f 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -139,7 +139,7 @@ def forward( self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - self.activation_distillation_loss(hidden_states, bias, kwargs) + hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs) if self._debug.enabled: self._debug( @@ -172,8 +172,10 @@ def forward( return hidden_states def activation_distillation_loss(self, hidden_states, bias, kwargs): + """ + Maybe apply activation distillation loss and setup backward hooks + """ mixer_output = hidden_states if bias is None else hidden_states + bias - # TODO: wrap in method: activation_distillation root_kwargs = kwargs.get(BlockKwargs.root, kwargs) # Teacher populates mixer activations for distillation. activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) @@ -200,6 +202,7 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs): activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) root_kwargs[BlockKwargs.activation_distillation_total] = activation_total + return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(normalization, bias_dropout_add) From f1bfca967743ca53a6530f918bab4bd57a6a1b0c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 14 Nov 2025 21:04:57 +0000 Subject: [PATCH 07/28] move activation distillation loss reporting to decoder block --- fast_llm/layers/decoder/block.py | 33 ++++++++++++++++--- fast_llm/layers/decoder/config.py | 16 +++++++++ fast_llm/layers/language_model/config.py | 8 ----- fast_llm/layers/language_model/head.py | 31 ++--------------- .../layers/language_model/language_model.py | 2 +- 5 files changed, 48 insertions(+), 42 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 7b058f5f..86f1e8c0 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -14,6 +14,7 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig +from fast_llm.layers.language_model.head import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -139,7 +140,7 @@ def forward( self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs) + hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) if self._debug.enabled: self._debug( @@ -171,7 +172,7 @@ def forward( hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - def activation_distillation_loss(self, hidden_states, bias, kwargs): + def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): """ Maybe apply activation distillation loss and setup backward hooks """ @@ -192,7 +193,12 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs): teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: handle sequence-first? - activation_loss = torch.mean(torch.norm(mixer_output - teacher_tensor, p=2, dim=(1, 2))) + # TODO: un-scaled loss for reporting? Average loss over layers? 
+ # L2 loss + activation_loss_factor = self._config.activation_distillation_factor + activation_loss = activation_loss_factor * torch.mean( + torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)) + ) # Backward hooks hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None @@ -202,6 +208,9 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs): activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() ) root_kwargs[BlockKwargs.activation_distillation_total] = activation_total + + if losses is not None and self._activation_distillation_loss_name in losses: + losses[self._activation_distillation_loss_name].append(activation_total.detach()) return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: @@ -217,5 +226,21 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self.mixer.preprocess(batch, kwargs) self.mlp.preprocess(batch, kwargs) + # TODO: add layer_index + _activation_distillation_loss_name = "activation_distillation_loss" + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) + loss_definitions = [] + if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: + loss_definitions.append( + LossDef( + name=self._activation_distillation_loss_name, + formatted_name=_format_name(self._activation_distillation_loss_name), + count=count, + ) + ) + return ( + loss_definitions + + self.mixer.get_loss_definitions(count=count) + + self.mlp.get_loss_definitions(count=count) + ) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 403b204c..99331ee7 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -88,6 +88,22 @@ class DecoderBlockConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for activation-level distillation.", + hint=FieldHint.feature, + ) + activation_distillation_factor: float = Field( + default=0.0, + desc="Factor to scale the activation-level distillation loss by.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + super()._validate() + if self.activation_distillation_factor > 0.0 and self.distillation_model is None: + raise ValueError("Activation distillation requires a distillation_model.") @property def layer_class(self) -> "type[DecoderBlock]": diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 6ede28b9..25fa2d91 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -162,12 +162,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) - activation_distillation_factor: float = Field( - default=0.0, - desc="Factor to scale the activation-level distillation loss by when using distillation.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -239,8 +233,6 @@ def _validate(self) -> None: 
self.language_model_loss_factor = 0.0 super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both - if self.activation_distillation_factor > 0.0 and self.distillation_model is None: - raise ValueError("Activation distillation requires a distillation_model to be configured.") @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 290f5014..4b0e3d10 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -18,7 +18,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -419,17 +419,6 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None - activation_loss = None - root_kwargs = kwargs.get(BlockKwargs.root, kwargs) - activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - if activation_total is not None and self._config.activation_distillation_factor > 0.0: - activation_loss = activation_total * self._config.activation_distillation_factor - if losses is not None and self._activation_distillation_loss_name in losses: - losses[self._activation_distillation_loss_name].append(activation_loss.detach()) - # Activation targets are no longer needed past this point. - root_kwargs.pop(BlockKwargs.activation_distillation_targets, None) - root_kwargs.pop(BlockKwargs.activation_distillation_total, None) - # TODO: de-allocate earlier. del logits @@ -437,7 +426,7 @@ def _logits_cross_entropy_forward_backward( grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? 
- loss = _add_tensors(dpo_loss, lm_loss, distillation_loss, activation_loss) + loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -483,13 +472,6 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _activation_distillation_loss_name(self) -> str: - name = "activation_distillation_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] if self._config.logit_z_loss: @@ -518,15 +500,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) ) - if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: - loss_defs.append( - LossDef( - name=self._activation_distillation_loss_name, - formatted_name=_format_name(self._activation_distillation_loss_name), - count=count, - ) - ) - return loss_defs diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 579eb531..0b8157d7 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -61,7 +61,7 @@ def get_layers(self) -> list[Layer]: def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: # TODO: remove root_kwargs - activation_factor = getattr(self.head._config, "activation_distillation_factor", 0.0) + activation_factor = getattr(self.decoder.config.block, "activation_distillation_factor", 0.0) if ( activation_factor > 0.0 or BlockKwargs.activation_distillation_targets in kwargs From 8b1675203dbcbbf7a0a9e6b2141ec99669166fd0 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 14 Nov 2025 21:25:02 +0000 Subject: [PATCH 08/28] fix logging --- fast_llm/layers/decoder/block.py | 10 +++------- fast_llm/models/gpt/model.py | 7 ++++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 86f1e8c0..9867019c 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -196,6 +196,8 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): # TODO: un-scaled loss for reporting? Average loss over layers? # L2 loss activation_loss_factor = self._config.activation_distillation_factor + # (batch, sequence, hidden). Take the norm over hidden dim. + # TODO: handle possible padding? 
activation_loss = activation_loss_factor * torch.mean( torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)) ) @@ -203,14 +205,8 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None # Logging - activation_total = root_kwargs.get(BlockKwargs.activation_distillation_total) - activation_total = ( - activation_loss.detach() if activation_total is None else activation_total + activation_loss.detach() - ) - root_kwargs[BlockKwargs.activation_distillation_total] = activation_total - if losses is not None and self._activation_distillation_loss_name in losses: - losses[self._activation_distillation_loss_name].append(activation_total.detach()) + losses[self._activation_distillation_loss_name].append(activation_loss.detach()) return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 25affc49..17187d0b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -176,8 +176,9 @@ def preprocess_batch( non_blocking=True, ) - distillation_model = getattr(self._config.head, "distillation_model", None) - activation_factor = getattr(self._config.head, "activation_distillation_factor", 0.0) + # TODO: decoder doesn't necessarily have a `block` attribute + distillation_model = self._config.decoder.block.distillation_model + activation_factor = self._config.decoder.block.activation_distillation_factor reference_logits: list[dict[str, typing.Any]] | None = None reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -190,7 +191,7 @@ def preprocess_batch( reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration, - setup_activation_storage=activation_factor > 0.0, + setup_activation_storage=activation_factor > 0.0 and distillation_model == name, ) # TODO: Do things work with >1? From efa8cf0d613d865b13c8b11412ca9ae04d1ada5f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 14 Nov 2025 22:01:07 +0000 Subject: [PATCH 09/28] remove root kwargs --- fast_llm/layers/block/config.py | 1 - fast_llm/layers/decoder/block.py | 5 ++--- .../layers/language_model/language_model.py | 20 ------------------- 3 files changed, 2 insertions(+), 24 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 45dbe495..dfc80a47 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -37,7 +37,6 @@ class BlockKwargs: sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? 
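For a (batch, sequence, hidden) activation, the loss has now moved from the element-wise MSE of the first draft, through a per-sample norm over sequence and hidden, to the mean of per-token L2 norms over the hidden dimension, scaled by activation_distillation_factor. A toy comparison of the first and last variants (arbitrary shapes and factor):

import torch

batch, seq, hidden = 4, 16, 8
factor = 0.1
student = torch.randn(batch, seq, hidden)
teacher = torch.randn(batch, seq, hidden)
diff = student - teacher

# First draft: element-wise mean squared error.
mse = torch.nn.functional.mse_loss(student, teacher)  # == diff.pow(2).mean()

# Current form: L2 norm over the hidden dimension per token, averaged over batch and sequence.
per_token_l2 = torch.norm(diff, p=2, dim=2)  # shape (batch, sequence)
activation_loss = factor * per_token_l2.mean()

# The MSE averages squared entries, while the per-token norm grows roughly with sqrt(hidden),
# so the same factor weights the two formulations very differently.
print(mse.item(), activation_loss.item())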
grad_output = "grad_output" - root = "_root_kwargs" activation_distillation_storage = "activation_distillation_storage" activation_distillation_targets = "activation_distillation_targets" activation_distillation_total = "activation_distillation_total" diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 9867019c..05df6e67 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -177,13 +177,12 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): Maybe apply activation distillation loss and setup backward hooks """ mixer_output = hidden_states if bias is None else hidden_states + bias - root_kwargs = kwargs.get(BlockKwargs.root, kwargs) # Teacher populates mixer activations for distillation. - activation_storage = root_kwargs.get(BlockKwargs.activation_distillation_storage) + activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage) if activation_storage is not None: activation_storage[self.module_name] = mixer_output.detach() # Student gets teacher activations and computes the activation-level loss. - activation_targets = root_kwargs.get(BlockKwargs.activation_distillation_targets) + activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets) if ( activation_targets is not None and self.training diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 0b8157d7..2e46bb57 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -8,7 +8,6 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -60,25 +59,6 @@ def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - # TODO: remove root_kwargs - activation_factor = getattr(self.decoder.config.block, "activation_distillation_factor", 0.0) - if ( - activation_factor > 0.0 - or BlockKwargs.activation_distillation_targets in kwargs - or BlockKwargs.activation_distillation_storage in kwargs - ): - root_state = kwargs.get(BlockKwargs.root) - if root_state is None or root_state is kwargs: - root_state = {} - kwargs[BlockKwargs.root] = root_state - if BlockKwargs.activation_distillation_targets in kwargs: - root_state[BlockKwargs.activation_distillation_targets] = kwargs[ - BlockKwargs.activation_distillation_targets - ] - if BlockKwargs.activation_distillation_storage in kwargs: - root_state[BlockKwargs.activation_distillation_storage] = kwargs[ - BlockKwargs.activation_distillation_storage - ] # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) From 4cda56d69af6f226f091480584eccaa4e851e638 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 20:43:48 +0000 Subject: [PATCH 10/28] fix mistral mlp conversion --- fast_llm/models/gpt/conversion/mistral.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index b5db3fa0..28941bc8 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import safe_merge_dicts @@ -38,8 +40,26 @@ def _check_config(cls, config: AttentionConfig) -> None: assert not config.add_linear_biases +class MistrallMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + @classmethod + def _check_config(cls, config: MLPConfig) -> None: + assert not config.add_linear_biases + + class MistralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter + mlp_converter_class: typing.ClassVar[type[MistrallMLPConverter]] = MistrallMLPConverter class MistralDecoderConverter(LlamaDecoderConverter): From 41692e9fa2c8890231730abed464cad6e171e3bf Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 20:49:05 +0000 Subject: [PATCH 11/28] remove duplicate from apriel conversion --- fast_llm/models/gpt/conversion/apriel.py | 25 +----------------------- 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 7550df04..ffd2522c 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,18 +8,12 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat -from fast_llm.models.gpt.conversion.llama import ( - LlamaMLPConverter, - get_parameter_converter, - get_weight_and_bias_converters, -) +from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, - MistralBlockConverter, MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, @@ -229,23 +223,6 @@ def get_converters( ] -class AprielMLPConverter(LlamaMLPConverter): - @classmethod - def 
import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - out = super().export_config(config) - del out["mlp_bias"] - return out - - -class AprielBlockConverterBase(MistralBlockConverter): - mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter - - class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" From 99c42c06d8af13df9bfeaece1e9425de4f9932fb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 20:56:35 +0000 Subject: [PATCH 12/28] fix --- fast_llm/models/gpt/conversion/apriel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index ffd2522c..e16eac4d 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -14,6 +14,7 @@ from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, + MistralBlockConverter, MistralDecoderConverter, MistralHeadConverter, MistralHuggingfaceCheckpointHandler, @@ -223,12 +224,12 @@ def get_converters( ] -class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): +class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" -class AprielMamba2BlockConverter(AprielBlockConverterBase): +class AprielMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -240,7 +241,7 @@ class AprielBlockConverter: DiscreteMamba2Config: "m2d", } _converter_classes = { - AttentionConfig: AprielBlockConverterBase, + AttentionConfig: MistralBlockConverter, Mamba2Config: AprielMamba2BlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, } From d3df7a567e5a8a02400d79bc4aee735ff10a0942 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 21:12:37 +0000 Subject: [PATCH 13/28] move assert --- fast_llm/models/gpt/conversion/mistral.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 28941bc8..a9a0909e 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -48,14 +48,11 @@ def import_config(cls, config: dict) -> dict: @classmethod def export_config(cls, config: MLPConfig) -> dict: + assert not config.add_linear_biases out = super().export_config(config) del out["mlp_bias"] return out - @classmethod - def _check_config(cls, config: MLPConfig) -> None: - assert not config.add_linear_biases - class MistralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter From 8e04abaaa32e3a84387738e211e25bc96824502e Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 17 Nov 2025 21:58:00 +0000 Subject: [PATCH 14/28] remove tp-1 check for reference models --- fast_llm/engine/training/config.py | 1 - 1 file changed, 1 deletion(-) 
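The Mistral/Apriel converter fixes above follow a wrap-the-parent pattern: the importer injects the field the Llama-style converter expects but the Mistral HF config never carries, and the exporter asserts the invariant and strips the field back out. The classes below are illustrative stand-ins for that pattern, not the real converter classes.

from dataclasses import dataclass

@dataclass
class ToyMLPConfig:
    add_linear_biases: bool = False

class LlamaLikeMLPConverter:
    @classmethod
    def import_config(cls, config: dict) -> dict:
        return {"add_linear_biases": config["mlp_bias"]}

    @classmethod
    def export_config(cls, config: ToyMLPConfig) -> dict:
        return {"mlp_bias": config.add_linear_biases}

class MistralLikeMLPConverter(LlamaLikeMLPConverter):
    @classmethod
    def import_config(cls, config: dict) -> dict:
        config["mlp_bias"] = False  # the field is absent from Mistral-style HF configs
        return super().import_config(config)

    @classmethod
    def export_config(cls, config: ToyMLPConfig) -> dict:
        assert not config.add_linear_biases  # these MLPs never carry biases
        out = super().export_config(config)
        del out["mlp_bias"]  # drop the key the target HF config does not define
        return out

assert MistralLikeMLPConverter.import_config({}) == {"add_linear_biases": False}
assert MistralLikeMLPConverter.export_config(ToyMLPConfig()) == {}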
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 531bc206..867cca98 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -361,7 +361,6 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() From 3ebda84fc2452c19767b9ea24a2bec4308154b58 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 20 Nov 2025 14:56:27 +0000 Subject: [PATCH 15/28] fix reduce op --- fast_llm/functional/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d56dce98..c22319c1 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -151,7 +151,7 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.AVG, group=group) return loss, grad @@ -277,7 +277,7 @@ def _torch_reverse_kl_forward_backward( loss = (loss_per_sample * loss_mask).mean() if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.AVG, group=group) if grad_output is not None: loss.backward(torch.full_like(loss, grad_output)) From f729625af48d20d9622b48f9095a678f9b1eb5b9 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 21 Nov 2025 21:33:52 +0000 Subject: [PATCH 16/28] try: loss after norm --- fast_llm/layers/decoder/block.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 05df6e67..e9f27f40 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -140,7 +140,7 @@ def forward( self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) + # hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) if self._debug.enabled: self._debug( @@ -156,6 +156,7 @@ def forward( hidden_states = self.norm_2(input_) if self._debug.enabled: self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, _ = self.activation_distillation_loss(hidden_states, None, kwargs, losses) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) if self._debug.enabled: self._debug( From 0effa246a47c0d2890195a189f73e52c652f9b03 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 24 Nov 2025 19:06:22 +0000 Subject: [PATCH 17/28] handle non-fixed-sequence decoder --- fast_llm/layers/block/config.py | 12 ++++++++++++ fast_llm/layers/decoder/config.py | 5 +++++ fast_llm/models/gpt/model.py | 12 +++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index dfc80a47..5f85df04 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -85,6 +85,9 @@ def get_layer( peft=peft, ) + def get_distillation_models(self) -> set[str]: + return set() + 
@config_class(registry=True) class BlockSequenceConfig(BlockConfig): @@ -116,6 +119,9 @@ def layer_class(self) -> "type[FixedBlockSequence]": return FixedBlockSequence + def get_distillation_models(self) -> set[str]: + return self.block.get_distillation_models() + @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) class PatternBlockSequenceConfig(BlockSequenceConfig): @@ -162,3 +168,9 @@ def expanded_pattern(self) -> list[str]: def preprocessing_layers(self) -> dict[str, int]: # The index at which each block first appears. These blocks are used for preprocessing. return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)} + + def get_distillation_models(self) -> set[str]: + models = set() + for block in self.blocks.values(): + models.update(block.get_distillation_models()) + return models diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 99331ee7..c06f057a 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -127,3 +127,8 @@ def get_layer( peft=peft, return_input=return_input, ) + + def get_distillation_models(self) -> set[str]: + if self.distillation_model is not None and self.activation_distillation_factor > 0.0: + return {self.distillation_model} + return set() diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 17187d0b..d34ee3ed 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -176,9 +176,9 @@ def preprocess_batch( non_blocking=True, ) - # TODO: decoder doesn't necessarily have a `block` attribute - distillation_model = self._config.decoder.block.distillation_model - activation_factor = self._config.decoder.block.activation_distillation_factor + distillation_models = self._config.decoder.get_distillation_models() + # TODO: Support multiple distillation models? + assert len(distillation_models) <= 1 reference_logits: list[dict[str, typing.Any]] | None = None reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -191,7 +191,7 @@ def preprocess_batch( reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration, - setup_activation_storage=activation_factor > 0.0 and distillation_model == name, + setup_activation_storage=name in distillation_models, ) # TODO: Do things work with >1? @@ -272,7 +272,9 @@ def preprocess_batch( if reference_logits is not None: reference_payload = reference_logits[i] kwargs.update(reference_payload) - if distillation_model is not None and activation_factor > 0.0: + # TODO: Support multiple distillation models? 
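With the pattern-block support above, the set of activation-distillation teachers is collected bottom-up: a block config reports its own teacher (if any), sequence configs union the sets of their blocks, and the GPT model currently asserts there is at most one teacher overall. A stripped-down sketch of that propagation, using hypothetical config classes:

from dataclasses import dataclass, field

@dataclass
class ToyBlockConfig:
    distillation_model: str | None = None
    activation_distillation_factor: float = 0.0

    def get_distillation_models(self) -> set[str]:
        if self.distillation_model is not None and self.activation_distillation_factor > 0.0:
            return {self.distillation_model}
        return set()

@dataclass
class ToyPatternSequenceConfig:
    blocks: dict[str, ToyBlockConfig] = field(default_factory=dict)

    def get_distillation_models(self) -> set[str]:
        models: set[str] = set()
        for block in self.blocks.values():
            models |= block.get_distillation_models()
        return models

decoder = ToyPatternSequenceConfig(
    blocks={
        "attention": ToyBlockConfig(distillation_model="teacher", activation_distillation_factor=0.1),
        "mamba": ToyBlockConfig(),  # no activation distillation on this block type
    }
)
distillation_models = decoder.get_distillation_models()
assert distillation_models == {"teacher"}
assert len(distillation_models) <= 1  # the model currently supports at most one teacher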
+ assert len(distillation_models) <= 1 + for distillation_model in distillation_models: teacher_key = f"{distillation_model}_activations" if teacher_key in reference_payload: kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key) From f7a0837d5ba134a3941d2599dfc174b3eb3ef62f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 25 Nov 2025 18:21:23 +0000 Subject: [PATCH 18/28] patch creeping config params --- fast_llm/layers/block/config.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 5f85df04..fcc04981 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,4 +1,5 @@ import functools +import logging import typing import warnings @@ -8,12 +9,14 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, log if typing.TYPE_CHECKING: from fast_llm.layers.block.block import BlockBase from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence +logger = logging.getLogger(__name__) + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -174,3 +177,15 @@ def get_distillation_models(self) -> set[str]: for block in self.blocks.values(): models.update(block.get_distillation_models()) return models + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + # Patch creeping type parameters from pretrained model + # TODO: fix this + if "block" in default: + removed = default.pop("block") + log( + f"Removing 'block' from default dict in PatternBlockSequenceConfig._from_dict: {removed}", + log_fn=logger.warning, + ) + return super()._from_dict(default, strict=strict) From 6f2d5e3070f3d4188dd4aab6c63d194d1da916c1 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 25 Nov 2025 19:02:49 +0000 Subject: [PATCH 19/28] support pattern-block-sequence with compatible configs in export --- fast_llm/models/gpt/conversion/llama.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index a9249226..16582030 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -16,7 +16,7 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig @@ -420,7 +420,18 @@ def import_config(cls, config: dict) -> dict: @classmethod def export_config(cls, config: FixedBlockSequenceConfig) -> dict: - # TODO: Support PatternBlockSequenceConfig with compatible configs. 
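The converter hunk that starts here (patch 19/28, continued below) replaces the removed TODO: a PatternBlockSequenceConfig is exportable only when every block exports to an identical Hugging Face config. A minimal sketch of that compatibility check, with a hypothetical helper name.

# Hypothetical helper: collapse per-block exported configs into one, refusing to
# export when any two differ.
def collapse_equal_configs(exported: list[dict]) -> dict:
    first = exported[0]
    for other in exported[1:]:
        if other != first:
            raise ValueError(f"Cannot export heterogeneous blocks: {first} != {other}")
    return first

assert collapse_equal_configs([{"num_hidden_layers": 4}, {"num_hidden_layers": 4}]) == {"num_hidden_layers": 4}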
+ if isinstance(config, PatternBlockSequenceConfig): + # All exported block configs must be equal + exported_block_configs = [ + safe_merge_dicts( + cls.block_converter_class.export_config(block_config), + {"num_hidden_layers": config.num_blocks}, + ) + for block_config in config.blocks.values() + ] + for other in exported_block_configs[1:]: + Assert.eq(exported_block_configs[0], other) + return exported_block_configs[0] Assert.custom(isinstance, config, FixedBlockSequenceConfig) return safe_merge_dicts( cls.block_converter_class.export_config(config.block), From 90da831c601651e4971ceb31fed4680240ef4a8d Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 25 Nov 2025 19:03:10 +0000 Subject: [PATCH 20/28] move activation-distillation loss pre-norm again --- fast_llm/layers/decoder/block.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index e9f27f40..05df6e67 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -140,7 +140,7 @@ def forward( self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - # hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) + hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) if self._debug.enabled: self._debug( @@ -156,7 +156,6 @@ def forward( hidden_states = self.norm_2(input_) if self._debug.enabled: self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, _ = self.activation_distillation_loss(hidden_states, None, kwargs, losses) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) if self._debug.enabled: self._debug( From 52517190ecf61f16dedc8e68c7d305af2beece74 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 25 Nov 2025 19:32:17 +0000 Subject: [PATCH 21/28] support PatternBlockSequenceConfig in llama converter --- fast_llm/models/gpt/conversion/llama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 16582030..1a765dc1 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -419,7 +419,7 @@ def import_config(cls, config: dict) -> dict: } @classmethod - def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + def export_config(cls, config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: if isinstance(config, PatternBlockSequenceConfig): # All exported block configs must be equal exported_block_configs = [ @@ -441,15 +441,19 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict: @classmethod def get_converters( cls, - config: FixedBlockSequenceConfig, + config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: + # In the case of PatternBlockSequenceConfig, compatibility was already checked in export_config + block_config = ( + config.block if isinstance(config, FixedBlockSequenceConfig) else next(iter(config.blocks.values())) + ) converters = [] for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( - config.block, + block_config, f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, From d2858d6c4af6ca33015aa137c4927b8ac92af8be Mon Sep 17 
00:00:00 2001 From: Toolkit User Date: Mon, 1 Dec 2025 21:16:19 +0000 Subject: [PATCH 22/28] add distillation tests --- fast_llm/models/gpt/config.py | 1 + tests/utils/model_configs.py | 87 ++++++++++++++++++++++++++++++++++ tests/utils/run_test_script.py | 40 ++++++++++++++++ 3 files changed, 128 insertions(+) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a046..2b435d40 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -171,6 +171,7 @@ def _validate(self) -> None: prediction_heads = 1 expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None} + expected_names.update(self.model.base_model.decoder.get_distillation_models()) Assert.eq(self.reference_models.keys(), expected_names) for reference_model in self.reference_models.values(): diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7..fa1cde4a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -537,6 +537,93 @@ def _update_and_add_testing_config( }, ) +_update_and_add_testing_config( + # Tests logit distillation. + "mistral", + "mistral_distill_logits", + updates={ + ("model", "base_model", "head", "distillation_model"): "teacher", + ("reference_models"): { + "teacher": { + "model": { + "base_model": { + "embeddings": { + "vocab_size": MODEL_TEST_VOCAB_SIZE, + }, + "decoder": { + "block": { + "mixer": { + "head_groups": 2, + }, + }, + }, + }, + }, + }, + }, + }, + megatron_args=None, + checkpoint_format=MistralCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=1.5, + # Micro-sequence mode not supported with reference models (see model.py:198) + skip_tests=("ms",), +) + +_update_and_add_testing_config( + "mistral", + "mistral_distill_activations", + updates={ + ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", + ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.05, + ("reference_models"): { + "teacher": { + "model": { + "base_model": { + "head": { + "cross_entropy_implementation": "fused", + }, + "embeddings": { + "vocab_size": MODEL_TEST_VOCAB_SIZE, + }, + "decoder": { + "block": { + "mixer": { + "head_groups": 2, + }, + }, + "num_blocks": 2, + }, + "hidden_size": 256, + }, + }, + }, + }, + }, + # Megatron doesn't support sliding windows. + megatron_args=None, + checkpoint_format=MistralCheckpointFormat, + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=2, + # Micro-sequence mode not supported with reference models (see model.py:198) + skip_tests=("ms",), +) + _update_and_add_testing_config( # Tests mixture of experts, mixtral converter. 
"llama", diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 7d706ebd..e746a632 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -65,6 +65,39 @@ def run_test_script_base_path(model_testing_config, result_path, request): return result_path / "models" / model_testing_config.name +def _propagate_config_args_to_reference_models(config_args: list[str]) -> list[str]: + """ + Propagate certain model config args to reference models. + + Some config args that affect model behavior need to be applied to both + the main model and reference models to ensure compatibility. + """ + propagated_args = [] + # Patterns that should be propagated to reference models + # Only model-level configs should be propagated, not batch-level configs + # (batch is shared at the trainer level, not per-model) + propagate_patterns = [ + ("model", "base_model", "sequence_first"), + ] + + for arg in config_args: + if "=" not in arg: + continue + key, value = arg.split("=", 1) + key_tuple = tuple(key.split(".")) + + # Check if this arg should be propagated + for pattern in propagate_patterns: + if key_tuple == pattern: + # Add the reference model version of this arg + # For each reference model (we check if they exist in the config) + ref_key = f"reference_models.teacher.{key}" + propagated_args.append(f"{ref_key}={value}") + break + + return propagated_args + + def do_run_test_script_for_all_models( distributed_testing_config: DistributedTestingConfig, model_testing_config: ModelTestingConfig, @@ -73,12 +106,19 @@ def do_run_test_script_for_all_models( ): Assert.leq(distributed_testing_config.num_gpus, DistributedConfig.default_world_size) get_model_test_dataset() + + # Propagate certain config args to reference models if they exist + propagated_args = [] + if "reference_models" in str(model_testing_config.config_dict): + propagated_args = _propagate_config_args_to_reference_models(distributed_testing_config.config_args) + args = [ "fast-llm", runnable_type, model_testing_config.model_type, *model_testing_config.config_args, *distributed_testing_config.config_args, + *propagated_args, f"model.distributed.world_size={distributed_testing_config.num_gpus}", f"model.distributed.local_world_size={distributed_testing_config.num_gpus}", f"run.experiment_dir={base_path/distributed_testing_config.name}", From 280db138610eeebce91b702051f22e0144c4634c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 1 Dec 2025 21:53:57 +0000 Subject: [PATCH 23/28] update tests --- tests/utils/model_configs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index fa1cde4a..78286708 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -570,7 +570,7 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, compare_factor=1.5, # Micro-sequence mode not supported with reference models (see model.py:198) @@ -582,7 +582,7 @@ def _update_and_add_testing_config( "mistral_distill_activations", updates={ ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", - ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.05, + 
("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.01, ("reference_models"): { "teacher": { "model": { @@ -617,9 +617,9 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, - compare_factor=2, + compare_factor=8, # Micro-sequence mode not supported with reference models (see model.py:198) skip_tests=("ms",), ) From a46ed183eab52bb331cbf7867195553ac6fda467 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 2 Dec 2025 15:29:22 +0000 Subject: [PATCH 24/28] update tests, add reverse_kl --- tests/utils/model_configs.py | 40 ++++++++++++++++++++++++++-------- tests/utils/run_test_script.py | 1 + 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 78286708..0a976da2 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -570,19 +570,41 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, # failing: tp2, stp2, stp2_ce4 }, compare_factor=1.5, - # Micro-sequence mode not supported with reference models (see model.py:198) - skip_tests=("ms",), + # modes not supported with reference models + skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) _update_and_add_testing_config( - "mistral", + "mistral_distill_logits", + "mistral_reverse_kl", + updates={ + ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl", + }, + megatron_args=None, + checkpoint_format=MistralCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, # failing: fp16, tp2, stp2, stp2_ce4 + }, + compare_factor=2, + # modes not supported with reference models + skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), +) + +_update_and_add_testing_config( + "mistral_distill_logits", "mistral_distill_activations", updates={ + ("model", "base_model", "head", "distillation_loss_factor"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", - ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.01, + ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1, ("reference_models"): { "teacher": { "model": { @@ -599,7 +621,7 @@ def _update_and_add_testing_config( "head_groups": 2, }, }, - "num_blocks": 2, + "num_blocks": 2, # number of blocks and hidden-size must match student }, "hidden_size": 256, }, @@ -617,11 +639,11 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: 
ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=8, - # Micro-sequence mode not supported with reference models (see model.py:198) - skip_tests=("ms",), + # modes not supported with reference models + skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2", "stp2_ce4"), ) _update_and_add_testing_config( diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index e746a632..965c606f 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -78,6 +78,7 @@ def _propagate_config_args_to_reference_models(config_args: list[str]) -> list[s # (batch is shared at the trainer level, not per-model) propagate_patterns = [ ("model", "base_model", "sequence_first"), + ("model", "base_model", "embeddings", "vocab_parallel"), ] for arg in config_args: From 6e429447fddccf35a3819ac123425eec482ccdeb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 2 Dec 2025 15:52:26 +0000 Subject: [PATCH 25/28] remove comments --- fast_llm/layers/decoder/block.py | 3 +-- fast_llm/models/gpt/model.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 85beb759..55532a16 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -180,11 +180,10 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): # Compare student mixer output with the teacher’s stored activation and accumulate the loss. teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) Assert.eq(teacher_tensor.shape, mixer_output.shape) - # TODO: handle sequence-first? # TODO: un-scaled loss for reporting? Average loss over layers? # L2 loss activation_loss_factor = self._config.activation_distillation_factor - # (batch, sequence, hidden). Take the norm over hidden dim. + # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. # TODO: handle possible padding? activation_loss = activation_loss_factor * torch.mean( torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index dea79e41..c16e9781 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -170,7 +170,6 @@ def preprocess_batch( distillation_models = self._config.decoder.get_distillation_models() # TODO: Support multiple distillation models? 
assert len(distillation_models) <= 1 - reference_logits: list[dict[str, typing.Any]] | None = None reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ From 4c75e1013a6bde4260b44da0e9cedf82d14a1fd0 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 2 Dec 2025 19:32:50 +0000 Subject: [PATCH 26/28] handle stp --- fast_llm/layers/decoder/block.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 55532a16..3dad5fb3 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -4,7 +4,7 @@ import torch -from fast_llm.core.distributed import set_generator +from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -140,13 +140,6 @@ def forward( hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "mixer output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs) @@ -177,7 +170,7 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): and self.training and (teacher_output := activation_targets.pop(self.module_name, None)) is not None ): - # Compare student mixer output with the teacher’s stored activation and accumulate the loss. + # Compare student mixer output with the teacher's stored activation and accumulate the loss. teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: un-scaled loss for reporting? Average loss over layers? @@ -185,9 +178,19 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): activation_loss_factor = self._config.activation_distillation_factor # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. # TODO: handle possible padding? 
- activation_loss = activation_loss_factor * torch.mean( - torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)) - ) + local_loss_sum = torch.sum(torch.norm(mixer_output - teacher_tensor, p=2, dim=(2))) + # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden) + # In either case, dims 0 and 1 are batch and sequence + total_count = mixer_output.shape[0] * mixer_output.shape[1] + + # All-reduce across tensor-parallel group if sequence-parallel is enabled + if self._sequence_parallel and self._distributed.tensor_group is not None: + all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM) + # Assume all ranks contribute the same count (not the case if padding) + total_count *= self._distributed.tensor_group.size() + + activation_loss = activation_loss_factor * (local_loss_sum / total_count) + # Backward hooks hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None From 035d36cb104491a382776f383195a365b2631ec3 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 2 Dec 2025 19:43:08 +0000 Subject: [PATCH 27/28] set distillation test as broken --- tests/utils/model_configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b2c3cb78..27b4335f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -571,7 +571,7 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, # failing: tp2, stp2, stp2_ce4 + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: tp2, stp2, stp2_ce4 }, compare_factor=1.5, # modes not supported with reference models @@ -592,7 +592,7 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, # failing: fp16, tp2, stp2, stp2_ce4 + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 }, compare_factor=2, # modes not supported with reference models @@ -640,7 +640,7 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, df4, df4_sf, tp2, stp2, }, compare_factor=8, # modes not supported with reference models From e3ac422ce1e3047d7169882809d68c69e1c04421 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 2 Dec 2025 21:13:19 +0000 Subject: [PATCH 28/28] remove unused code --- fast_llm/layers/block/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index f4a439ee..bc45ea9f 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -42,7 +42,6 @@ class BlockKwargs: grad_output = "grad_output" activation_distillation_storage = 
"activation_distillation_storage" activation_distillation_targets = "activation_distillation_targets" - activation_distillation_total = "activation_distillation_total" iteration = "iteration" device = "device" hidden_states = "hidden_states"