1 change: 0 additions & 1 deletion fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
if self.run.experiment_dir is None:
assert not self.training.checkpoint.enabled()
6 changes: 6 additions & 0 deletions fast_llm/functional/config.py
@@ -89,6 +89,12 @@ def _set_activation_fn_map() -> None:
MAX_DROPLESS_BLOCK_SIZE_ROW = 128


class ReverseKLImpl(str, enum.Enum):
Collaborator review comment: Not needed, there is only one implementation.

tp = "tp"
stp = "stp"
no_tp = "no_tp"


class CrossEntropyImpl(str, enum.Enum):
auto = "auto"
torch = "torch"
161 changes: 90 additions & 71 deletions fast_llm/functional/cross_entropy.py
@@ -1,7 +1,7 @@
import torch

from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce
from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat
from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
from fast_llm.utils import Assert

@@ -144,14 +144,20 @@ def _fused_cross_entropy_forward_backward(
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
else:
predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)
if group is not None:
# This is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is the logit.
# We then average: 1/K sum_ranks (log Z - sum_i t_i * z_i)
# = log Z - 1/K sum_ranks (sum_i t_i * z_i),
# but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over the full vocab).
predicted_logits = predicted_logits * group.size()
Collaborator review comment: This looks wrong, see previous comment. The previous version was tested and confirmed to work.

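(Aside, not part of the PR: a CPU-only toy check of the algebraic identity the comment above relies on. Averaging K * (local sum of t_i * z_i) over K vocab shards recovers the full-vocab sum; whether that is the right reduction for this code path is what this thread is debating.)

import torch

torch.manual_seed(0)
K, vocab = 4, 32
logits = torch.randn(vocab)
target = torch.softmax(torch.randn(vocab), dim=-1)  # full-vocab target probabilities
shard = vocab // K

full_sum = (target * logits).sum()
local_sums = [(target[r * shard:(r + 1) * shard] * logits[r * shard:(r + 1) * shard]).sum() for r in range(K)]
# mean over "ranks" of K * local_sum equals the sum of local sums, i.e. the full-vocab sum
assert torch.allclose(torch.stack([K * s for s in local_sums]).mean(), full_sum, atol=1e-5)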

per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.AVG, group=group)

return loss, grad

@@ -213,71 +219,72 @@ def cross_entropy_forward_backward(
)


def distributed_log_softmax(
logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
):
logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim)

return logits_norm - sum_exp_logits.log() # log_softmax

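(Aside, not part of the PR: a single-process sanity sketch with group=None. Assuming _fused_softmax_base returns the max-shifted logits and their summed exponentials, as the return expression above suggests, the helper should agree with torch.log_softmax.)

import torch
from fast_llm.functional.cross_entropy import distributed_log_softmax

x = torch.randn(4, 16)
assert torch.allclose(distributed_log_softmax(x), torch.log_softmax(x, dim=-1), atol=1e-5)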

def _torch_reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
Much simpler and more reliable than a custom implementation!
This is used for the TP version, where we split across the vocab dimension. KL is additive over partitions of the vocab (see the additivity sketch after this function).

Takes:
logits: [BxS, V] or [B, S, V]
target: [BxS, V] or [B, S, V] (logits format)
loss_mask: [BxS] or [B, S] or None
...
"""
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(
teacher_softmax_temperature,
1,
msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL",
)
Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# Compute log probabilities - let _fused_softmax handle scaling internally
# teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group)
# # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p)
# teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6
# teacher_log_probs = torch.log(teacher_probs)

# Scale target logits more carefully
scaled_target = target * (logits_scale_factor / teacher_softmax_temperature)

# Clamp to prevent extreme values before log_softmax
scaled_target = torch.clamp(scaled_target, min=-50, max=50)
teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)

# For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p))
# Use kl_div with: input=log(p), target=q, log_target=False
# This gives: Σ q * (log(q) - log(p)) = exactly what we want!

# Compute log probabilities
teacher_log_probs = distributed_log_softmax(target.float(), group=group)
# batch_size = logits.shape[0]
with torch.enable_grad():
logits_ = logits.detach().requires_grad_(grad_output is not None)

# Use log_softmax for consistency instead of _fused_softmax
scaled_logits = logits_ * logits_scale_factor
scaled_logits = torch.clamp(scaled_logits, min=-50, max=50)
student_log_probs = torch.log_softmax(scaled_logits, dim=-1)

# Convert to probabilities for kl_div
# student_probs_ = torch.exp(student_log_probs)
logits_ = logits.float().detach().requires_grad_(grad_output is not None)
student_log_probs = distributed_log_softmax(logits_, group=group)

# Reverse KL: input=teacher_log_probs (log p), target=student_log_probs (log q, log_target=True)
if loss_mask is None:
loss = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="batchmean",
log_target=True,
)
loss_terms = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="none",
log_target=True,
).sum(dim=-1)
if loss_mask is not None:
# loss mask is the same on all ranks for TP over vocab.
valid = loss_mask.to(loss_terms.dtype)
loss_terms = loss_terms * valid
valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype)
else:
# Apply loss mask - this requires some reshaping
loss_per_sample = torch.nn.functional.kl_div(
teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)
loss = (loss_per_sample * loss_mask).mean()
valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
loss = loss_terms.sum() # sums over batch and seq. len.

if group is not None and target_format != TargetFormat.labels:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
if group is not None:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= valid_tokens

if grad_output is not None:
loss.backward(torch.full_like(loss, grad_output))
@@ -288,6 +295,12 @@ def _torch_reverse_kl_forward_backward(
return loss.detach_(), grad

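(Aside, not part of the PR: a small single-process check of the vocab-additivity claim in the docstring above; the shard count is invented. As long as both log-probability tensors are normalized over the full vocabulary, which distributed_log_softmax is meant to ensure when a group is passed, KL(q||p) over the full vocab equals the sum of the per-shard contributions.)

import torch

torch.manual_seed(0)
vocab, shards = 32, 4
width = vocab // shards
log_q = torch.log_softmax(torch.randn(vocab), dim=-1)  # student, normalized over the full vocab
log_p = torch.log_softmax(torch.randn(vocab), dim=-1)  # teacher, normalized over the full vocab

full_kl = torch.nn.functional.kl_div(log_p, log_q, reduction="sum", log_target=True)
shard_kl = sum(
    torch.nn.functional.kl_div(log_p[s * width:(s + 1) * width], log_q[s * width:(s + 1) * width], reduction="sum", log_target=True)
    for s in range(shards)
)
assert torch.allclose(full_kl, shard_kl, atol=1e-5)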

REVERSE_KL_IMPLEMENTATIONS = {
ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward,
ReverseKLImpl.tp: _torch_reverse_kl_forward_backward,
}


def reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
@@ -297,6 +310,9 @@ def reverse_kl_forward_backward(
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
sequence_parallel_logits: bool = False,
group_size: int | None = None,
vocab_size: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher).
@@ -309,37 +325,40 @@ def reverse_kl_forward_backward(
- Standard CE: KL(p||q) = mode-covering (spreads mass broadly)
- Reverse KL: KL(q||p) = mode-seeking (focuses on target modes)

Args:
logits: Model predictions [batch_size, ..., vocab_size]
target: Target distribution or labels
loss_mask: Optional mask for loss computation
grad_output: Gradient output scale factor
group: Process group for tensor parallelism
logits_scale_factor: Temperature scaling factor (1/T)
target_format: Format of target (labels or logits)
Takes:
logits: [BxS, V] or [B, S, V], where V is local vocab size
target: [BxS, V] or [B, S, V] (logits format)
loss_mask: [BxS] or [B, S] or None
...

Returns:
loss: Reverse KL divergence loss
grad: Gradients w.r.t. logits

Example usage:
# Replace standard cross-entropy with reverse KL
# loss, grad = cross_entropy_forward_backward(logits, target, ...)
loss, grad = reverse_kl_forward_backward(logits, target,
loss_mask=None,
grad_output=1.0,
logits_scale_factor=1.0,
target_format=TargetFormat.labels)
"""
if target_format == TargetFormat.labels:
Assert.eq(target.shape, logits.shape[:-1])
Assert.eq(target.dtype, torch.int64)

if logits.shape[-1] != vocab_size:
reverse_kl_impl = ReverseKLImpl.tp
elif sequence_parallel_logits:
# TODO: see hybrid dev branch where it is implemented
raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true")
else:
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
reverse_kl_impl = ReverseKLImpl.no_tp

Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# TODO: implement fused?
return _torch_reverse_kl_forward_backward(
logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature
distillation_loss, distillation_grad = REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl](
logits=logits,
target=target,
loss_mask=loss_mask,
grad_output=grad_output,
logits_scale_factor=logits_scale_factor,
target_format=target_format,
teacher_softmax_temperature=teacher_softmax_temperature,
group=group,
)
return distillation_loss, distillation_grad
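(Aside, not part of the PR: a minimal single-device usage sketch with no tensor parallelism; shapes and values are invented. With group=None and logits.shape[-1] == vocab_size, the no_tp implementation is selected.)

import torch
from fast_llm.functional.config import TargetFormat
from fast_llm.functional.cross_entropy import reverse_kl_forward_backward

tokens, vocab = 8, 32
student_logits = torch.randn(tokens, vocab)
teacher_logits = torch.randn(tokens, vocab)

loss, grad = reverse_kl_forward_backward(
    student_logits,
    teacher_logits,
    loss_mask=None,
    grad_output=1.0,
    group=None,
    target_format=TargetFormat.logits,
    vocab_size=vocab,
)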
4 changes: 4 additions & 0 deletions fast_llm/layers/language_model/head.py
@@ -387,7 +387,11 @@ def _logits_cross_entropy_forward_backward(
target_format=(
TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits
),
sequence_parallel_logits=self._sequence_parallel_logits,
group_size=self._distributed_config.tensor_parallel,
vocab_size=self._vocab_dim.global_size,
)

elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy:
distillation_loss, distillation_grad = cross_entropy_forward_backward(
logits.flatten(0, -2),
8 changes: 4 additions & 4 deletions fast_llm/models/gpt/conversion/llama.py
@@ -207,10 +207,10 @@ def import_config(cls, config: dict) -> dict:
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": config["attention_factor"],
"beta_fast": config["beta_fast"],
"beta_slow": config["beta_slow"],
"original_context_length": config["original_max_position_embeddings"],
"attention_factor": config["rope_scaling"]["attention_factor"],
Collaborator review comment: Should be the same for llama 3 above?

"beta_fast": config["rope_scaling"]["beta_fast"],
"beta_slow": config["rope_scaling"]["beta_slow"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
else:
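(Aside, not part of the PR: a hedged illustration, with invented values, of how the YaRN parameters sit under the nested rope_scaling dict in a Hugging Face Llama-style config.json; the exact key set varies by model, but this nesting is what the corrected lookups above read from.)

config = {
    "max_position_embeddings": 131072,
    "rope_theta": 1000000.0,
    "rope_scaling": {
        "rope_type": "yarn",
        "factor": 4.0,
        "attention_factor": 1.0,
        "beta_fast": 32,
        "beta_slow": 1,
        "original_max_position_embeddings": 32768,
    },
}

attention_factor = config["rope_scaling"]["attention_factor"]  # the new lookup path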