Changes from all commits (32 commits)
9fa4c46
activation distillation: first draft
RaymondLi0 Nov 12, 2025
11708ff
fix kwargs
RaymondLi0 Nov 12, 2025
9437310
remove count, add auxiliaryLoss hook
RaymondLi0 Nov 13, 2025
d3ac964
fix auxiliary loss
RaymondLi0 Nov 13, 2025
56fc8db
wrap in method
RaymondLi0 Nov 13, 2025
5d75f01
fixes
RaymondLi0 Nov 13, 2025
f1bfca9
move activation distillation loss reporting to decoder block
RaymondLi0 Nov 14, 2025
8b16752
fix logging
RaymondLi0 Nov 14, 2025
efa8cf0
remove root kwargs
RaymondLi0 Nov 14, 2025
4cda56d
fix mistral mlp conversion
RaymondLi0 Nov 17, 2025
9ca2347
Merge branch 'raymond/fix_mistral_conv' into raymond/activation_disil…
RaymondLi0 Nov 17, 2025
41692e9
remove duplicate from apriel conversion
RaymondLi0 Nov 17, 2025
99c42c0
fix
RaymondLi0 Nov 17, 2025
d3df7a5
move assert
RaymondLi0 Nov 17, 2025
f2f097e
Merge branch 'raymond/fix_mistral_conv' into raymond/activation_disil…
RaymondLi0 Nov 17, 2025
8e04aba
remove tp-1 check for reference models
RaymondLi0 Nov 17, 2025
3ebda84
fix reduce op
RaymondLi0 Nov 20, 2025
6a8732f
Merge branch 'raymond/fix_distill_tp' into raymond/activation_disilla…
RaymondLi0 Nov 20, 2025
f729625
try: loss after norm
RaymondLi0 Nov 21, 2025
0effa24
handle non-fixed-sequence decoder
RaymondLi0 Nov 24, 2025
f7a0837
patch creeping config params
RaymondLi0 Nov 25, 2025
6f2d5e3
support pattern-block-sequence with compatible configs in export
RaymondLi0 Nov 25, 2025
90da831
move activation-distillation loss pre-norm again
RaymondLi0 Nov 25, 2025
5251719
support PatternBlockSequenceConfig in llama converter
RaymondLi0 Nov 25, 2025
d2858d6
add distillation tests
RaymondLi0 Dec 1, 2025
280db13
update tests
RaymondLi0 Dec 1, 2025
a46ed18
update tests, add reverse_kl
RaymondLi0 Dec 2, 2025
01b5530
Merge branch 'main' into raymond/activation_disillation
RaymondLi0 Dec 2, 2025
6e42944
remove comments
RaymondLi0 Dec 2, 2025
4c75e10
handle stp
RaymondLi0 Dec 2, 2025
035d36c
set distillation test as broken
RaymondLi0 Dec 2, 2025
e3ac422
remove unused code
RaymondLi0 Dec 2, 2025
1 change: 1 addition & 0 deletions fast_llm/engine/base_model/base_model.py
@@ -179,6 +179,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
setup_activation_storage: bool = False,
Collaborator comment: Not needed, you can communicate through preprocessed meta kwargs.

) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass
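The review comment above suggests dropping the new `setup_activation_storage` parameter and signalling through the preprocessed meta kwargs instead. A minimal sketch of that idea, assuming a hypothetical key name (this is not the repo's actual API):

```python
from typing import Any

# Assumed key name, for illustration only.
ACTIVATION_STORAGE_FLAG = "setup_activation_storage"


def build_kwargs_meta(kwargs_meta: dict[str, Any], wants_activations: bool) -> dict[str, Any]:
    # Attach the flag while the per-phase kwargs meta is built (e.g. in preprocess_meta).
    return {**kwargs_meta, ACTIVATION_STORAGE_FLAG: wants_activations}


def wants_activation_storage(kwargs_meta: dict[str, Any]) -> bool:
    # Read it back inside preprocess_batch; no signature change is needed.
    return bool(kwargs_meta.get(ACTIVATION_STORAGE_FLAG, False))
```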
1 change: 0 additions & 1 deletion fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
if self.run.experiment_dir is None:
assert not self.training.checkpoint.enabled()
4 changes: 2 additions & 2 deletions fast_llm/functional/cross_entropy.py
@@ -151,7 +151,7 @@ def _fused_cross_entropy_forward_backward(

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.AVG, group=group)

return loss, grad

@@ -277,7 +277,7 @@ def _torch_reverse_kl_forward_backward(
loss = (loss_per_sample * loss_mask).mean()

if group is not None and target_format != TargetFormat.labels:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.AVG, group=group)

if grad_output is not None:
loss.backward(torch.full_like(loss, grad_output))
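For context on the two one-line changes in this file: each rank already holds a per-rank mean loss, so the cross-rank reduction must average rather than sum. In plain `torch.distributed`, `ReduceOp.AVG` (NCCL backend) is the averaging op; a minimal standalone sketch, with names chosen for illustration:

```python
import torch
import torch.distributed as dist


def reduce_mean_loss(loss: torch.Tensor, group: dist.ProcessGroup | None) -> torch.Tensor:
    # AVG divides the summed result by the group size, keeping the value a mean.
    if group is not None:
        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=group)
    return loss
```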
31 changes: 30 additions & 1 deletion fast_llm/layers/block/config.py
@@ -1,4 +1,5 @@
import functools
import logging
import typing
import warnings

@@ -8,12 +9,14 @@
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.utils import Assert
from fast_llm.utils import Assert, log

if typing.TYPE_CHECKING:
from fast_llm.layers.block.block import BlockBase
from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence

logger = logging.getLogger(__name__)


class BlockDimNames:
# A set of common tensor dim names packed into a namespace.
@@ -37,6 +40,8 @@ class BlockKwargs:
sequence_lengths = "sequence_lengths"
# TODO: Belongs elsewhere?
grad_output = "grad_output"
activation_distillation_storage = "activation_distillation_storage"
activation_distillation_targets = "activation_distillation_targets"
iteration = "iteration"
device = "device"
hidden_states = "hidden_states"
@@ -87,6 +92,9 @@ def get_layer(
peft=peft,
)

def get_distillation_models(self) -> set[str]:
return set()


@config_class(registry=True)
class BlockSequenceConfig(BlockConfig):
@@ -118,6 +126,9 @@ def layer_class(self) -> "type[FixedBlockSequence]":

return FixedBlockSequence

def get_distillation_models(self) -> set[str]:
return self.block.get_distillation_models()


@config_class(dynamic_type={BlockSequenceConfig: "pattern"})
class PatternBlockSequenceConfig(BlockSequenceConfig):
@@ -164,3 +175,21 @@ def expanded_pattern(self) -> list[str]:
def preprocessing_layers(self) -> dict[str, int]:
# The index at which each block first appears. These blocks are used for preprocessing.
return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)}

def get_distillation_models(self) -> set[str]:
models = set()
for block in self.blocks.values():
models.update(block.get_distillation_models())
return models

@classmethod
def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
# Patch creeping type parameters from pretrained model
# TODO: fix this
if "block" in default:
removed = default.pop("block")
log(
f"Removing 'block' from default dict in PatternBlockSequenceConfig._from_dict: {removed}",
log_fn=logger.warning,
)
return super()._from_dict(default, strict=strict)
71 changes: 69 additions & 2 deletions fast_llm/layers/decoder/block.py
@@ -4,16 +4,19 @@

import torch

from fast_llm.core.distributed import set_generator
from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator
from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.block.block import Block
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
from fast_llm.layers.language_model.head import _format_name
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -134,6 +137,9 @@ def forward(
hidden_states = self.norm_1(input_)
self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs)
hidden_states, bias = self.mixer(hidden_states, kwargs)

hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)

with set_generator(generator):
input_ = self._bias_dropout_add(hidden_states, bias, input_)
self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs)
@@ -148,6 +154,51 @@ def forward(
hidden_states = torch.stack((fw_input, hidden_states), dim=0)
return hidden_states

def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
"""
Maybe apply activation distillation loss and setup backward hooks
"""
mixer_output = hidden_states if bias is None else hidden_states + bias
Collaborator comment: This should only be evaluated if needed.

# Teacher populates mixer activations for distillation.
activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
if activation_storage is not None:
Collaborator comment: Consider using the new _debug / output_hidden_states interface instead? It does the exact same thing.

activation_storage[self.module_name] = mixer_output.detach()
# Student gets teacher activations and computes the activation-level loss.
activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
if (
activation_targets is not None
and self.training
and (teacher_output := activation_targets.pop(self.module_name, None)) is not None
):
# Compare student mixer output with the teacher's stored activation and accumulate the loss.
teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
Assert.eq(teacher_tensor.shape, mixer_output.shape)
# TODO: un-scaled loss for reporting? Average loss over layers?
# L2 loss
activation_loss_factor = self._config.activation_distillation_factor
# (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim.
# TODO: handle possible padding?
local_loss_sum = torch.sum(torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)))
# mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden)
# In either case, dims 0 and 1 are batch and sequence
total_count = mixer_output.shape[0] * mixer_output.shape[1]

# All-reduce across tensor-parallel group if sequence-parallel is enabled
if self._sequence_parallel and self._distributed.tensor_group is not None:
all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM)
# Assume all ranks contribute the same count (not the case if padding)
total_count *= self._distributed.tensor_group.size()

activation_loss = activation_loss_factor * (local_loss_sum / total_count)

# Backward hooks
hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
# Logging
if losses is not None and self._activation_distillation_loss_name in losses:
losses[self._activation_distillation_loss_name].append(activation_loss.detach())
return hidden_states, bias

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
# TODO: Add marginal compute? (normalization, bias_dropout_add)
return sum(
@@ -161,5 +212,21 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
self.mixer.preprocess(kwargs)
self.mlp.preprocess(kwargs)

# TODO: add layer_index
_activation_distillation_loss_name = "activation_distillation_loss"

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
loss_definitions = []
if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
loss_definitions.append(
LossDef(
name=self._activation_distillation_loss_name,
formatted_name=_format_name(self._activation_distillation_loss_name),
count=count,
)
)
return (
loss_definitions
+ self.mixer.get_loss_definitions(count=count)
+ self.mlp.get_loss_definitions(count=count)
)
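For reference, the loss computed in `activation_distillation_loss` above, restated as a standalone single-rank sketch: the L2 norm of the student/teacher mixer-output difference over the hidden dimension, averaged over the batch and sequence dimensions and scaled by the configured factor (the sequence-parallel all-reduce is omitted; tensor names are illustrative):

```python
import torch


def activation_l2_loss(student: torch.Tensor, teacher: torch.Tensor, factor: float) -> torch.Tensor:
    # student, teacher: (batch, sequence, hidden) or (sequence, batch, hidden);
    # in either layout, dims 0 and 1 are batch and sequence.
    teacher = teacher.detach().to(device=student.device, dtype=student.dtype)
    per_token = torch.norm(student - teacher, p=2, dim=2)  # L2 over the hidden dimension
    return factor * per_token.sum() / (student.shape[0] * student.shape[1])
```

In the diff, the resulting scalar is not returned directly: it is routed into the backward pass through `AuxiliaryLoss.apply` on the mixer output (and bias), and appended to `losses` for reporting.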
21 changes: 21 additions & 0 deletions fast_llm/layers/decoder/config.py
@@ -200,6 +200,22 @@ class DecoderBlockConfig(BlockConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
distillation_model: str | None = Field(
default=None,
desc="Name of the reference model to use for activation-level distillation.",
hint=FieldHint.feature,
)
activation_distillation_factor: float = Field(
default=0.0,
desc="Factor to scale the activation-level distillation loss by.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)

def _validate(self) -> None:
super()._validate()
if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
raise ValueError("Activation distillation requires a distillation_model.")

@property
def layer_class(self) -> "type[DecoderBlock]":
@@ -223,3 +239,8 @@ def get_layer(
peft=peft,
return_input=return_input,
)

def get_distillation_models(self) -> set[str]:
if self.distillation_model is not None and self.activation_distillation_factor > 0.0:
return {self.distillation_model}
return set()
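A minimal sketch of how the two new fields are meant to be used together, per the `_validate` check above. The model name `teacher` is an assumption; it must match a reference model configured on the GPT model (see the `gpt/config.py` change below):

```python
# Dict-style decoder-block configuration enabling activation distillation (sketch).
decoder_block_config = {
    "distillation_model": "teacher",          # assumed name of the reference (teacher) model
    "activation_distillation_factor": 0.01,   # 0.0 (the default) disables the auxiliary loss
}
# A factor > 0 with no distillation_model raises ValueError in _validate above.
```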
1 change: 1 addition & 0 deletions fast_llm/models/gpt/config.py
@@ -167,6 +167,7 @@ def _validate(self) -> None:
prediction_heads = 1

expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None}
expected_names.update(self.model.base_model.decoder.get_distillation_models())
Assert.eq(self.reference_models.keys(), expected_names)

for reference_model in self.reference_models.values():
25 changes: 20 additions & 5 deletions fast_llm/models/gpt/conversion/llama.py
@@ -16,7 +16,7 @@
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig
from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
from fast_llm.layers.block.config import FixedBlockSequenceConfig
from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig
from fast_llm.layers.common.normalization.config import RMSNormalizationConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
@@ -419,8 +419,19 @@ def import_config(cls, config: dict) -> dict:
}

@classmethod
def export_config(cls, config: FixedBlockSequenceConfig) -> dict:
# TODO: Support PatternBlockSequenceConfig with compatible configs.
def export_config(cls, config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict:
if isinstance(config, PatternBlockSequenceConfig):
# All exported block configs must be equal
exported_block_configs = [
safe_merge_dicts(
cls.block_converter_class.export_config(block_config),
{"num_hidden_layers": config.num_blocks},
)
for block_config in config.blocks.values()
]
for other in exported_block_configs[1:]:
Assert.eq(exported_block_configs[0], other)
return exported_block_configs[0]
Assert.custom(isinstance, config, FixedBlockSequenceConfig)
return safe_merge_dicts(
cls.block_converter_class.export_config(config.block),
@@ -430,15 +441,19 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict:
@classmethod
def get_converters(
cls,
config: FixedBlockSequenceConfig,
config: FixedBlockSequenceConfig | PatternBlockSequenceConfig,
fast_llm_prefix: str,
hf_prefix: str,
drop_on_export: bool = False,
) -> list[WeightConverter]:
# In the case of PatternBlockSequenceConfig, compatibility was already checked in export_config
block_config = (
config.block if isinstance(config, FixedBlockSequenceConfig) else next(iter(config.blocks.values()))
)
converters = []
for block_index in range(config.num_blocks):
converters += cls.block_converter_class.get_converters(
config.block,
block_config,
f"{fast_llm_prefix}.{block_index}",
f"{hf_prefix}.{block_index}",
drop_on_export,
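Restating the compatibility rule added to the converter, as a short sketch: a `PatternBlockSequenceConfig` can only be exported to the flat Hugging Face layout when every block exports to the same config, in which case any one of them (plus the total layer count) represents the whole sequence. `export_block` here is a hypothetical stand-in for the block converter's `export_config`:

```python
def export_pattern_sequence(block_configs: list, num_blocks: int, export_block) -> dict:
    # export_block is a hypothetical callable standing in for the block converter.
    exported = [export_block(cfg) for cfg in block_configs]
    assert all(e == exported[0] for e in exported), "pattern blocks must export identically"
    return {**exported[0], "num_hidden_layers": num_blocks}
```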
25 changes: 24 additions & 1 deletion fast_llm/models/gpt/model.py
@@ -157,6 +157,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
setup_activation_storage: bool = False,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
assert self._is_setup
@@ -166,21 +167,33 @@ def preprocess_batch(
if preprocessed_meta is None:
preprocessed_meta = self.preprocess_meta(batch, phase)

distillation_models = self._config.decoder.get_distillation_models()
# TODO: Support multiple distillation models?
assert len(distillation_models) <= 1
reference_logits = [{} for _ in preprocessed_meta]
for name, reference_model in self._reference_models.items():
reference_preprocessed_meta = [
(tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta
]

reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch(
batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
batch,
reference_preprocessed_meta,
phase=PhaseType.inference,
iteration=iteration,
setup_activation_storage=name in distillation_models,
)

# TODO: Do things work with >1?
Assert.eq(len(reference_batch), len(preprocessed_meta), 1)
for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
if BlockKwargs.activation_distillation_storage in reference_kwargs:
reference_logits[i][f"{name}_activations"] = reference_kwargs[
BlockKwargs.activation_distillation_storage
]
del reference_kwargs[BlockKwargs.activation_distillation_storage]

preprocessed = []
presents = None
@@ -205,6 +218,16 @@ def preprocess_batch(
**reference_logits[i],
}

if setup_activation_storage:
# Teacher: add storage for activations
kwargs.setdefault(BlockKwargs.activation_distillation_storage, {})
# Add activation-distillation targets
assert len(distillation_models) <= 1
for distillation_model in distillation_models:
teacher_key = f"{distillation_model}_activations"
if teacher_key in reference_logits[i]:
kwargs[BlockKwargs.activation_distillation_targets] = reference_logits[i].pop(teacher_key)

if phase != PhaseType.inference:
labels_begin = tokens_begin + 1
labels_end = tokens_end + self._config.head.max_prediction_distance
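Putting the `preprocess_batch` changes above together, a sketch of the teacher-to-student activation handoff, with illustrative names only: the teacher run is given an empty activation-storage dict, its decoder blocks fill it by module name during the forward pass, and the student receives that dict back as its distillation targets:

```python
from typing import Any, Callable

import torch


def run_teacher(
    teacher_forward: Callable[[torch.Tensor, dict[str, Any]], None],
    tokens: torch.Tensor,
    kwargs: dict[str, Any],
) -> dict[str, torch.Tensor]:
    kwargs["activation_distillation_storage"] = {}  # teacher-side storage, filled by the blocks
    teacher_forward(tokens, kwargs)
    return kwargs.pop("activation_distillation_storage")


def prepare_student_kwargs(kwargs: dict[str, Any], teacher_activations: dict[str, torch.Tensor]) -> dict[str, Any]:
    # Each student DecoderBlock later pops its own entry (by module name) and computes the L2 term.
    kwargs["activation_distillation_targets"] = teacher_activations
    return kwargs
```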