From 1a163003f9a450b0a6b88ed52dc05c3ad5d2543c Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 16:47:25 +0000 Subject: [PATCH 01/10] wip --- fast_llm/engine/training/config.py | 4 +- fast_llm/functional/config.py | 6 + fast_llm/functional/cross_entropy.py | 155 ++++++++++++++++--------- fast_llm/layers/language_model/head.py | 4 + 4 files changed, 109 insertions(+), 60 deletions(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 531bc206..b836516a 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -361,8 +361,8 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - Assert.eq(self.model.distributed.tensor_parallel, 1) - Assert.eq(self.model.distributed.sequence_data_parallel, 1) + # Assert.eq(self.model.distributed.tensor_parallel, 1) + # Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() for reference_model in self.reference_models.values(): diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 68419384..012a04dd 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -89,6 +89,12 @@ def _set_activation_fn_map() -> None: MAX_DROPLESS_BLOCK_SIZE_ROW = 128 +class ReverseKLImpl(str, enum.Enum): + tp = "tp" + stp = "stp" + no_tp = "no_tp" + + class CrossEntropyImpl(str, enum.Enum): auto = "auto" torch = "torch" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d56dce98..0ecdddcc 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat +from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -145,13 +145,19 @@ def _fused_cross_entropy_forward_backward( else: predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + # shouldn't the predicted_logits be scaled by the number of ranks so that the average loss is correct? i.e. + # i.e. on each rank we calculate log Z - sum_i t_i * z_i, z_i is logit. + # then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i) + # = log Z - 1/K sum_ranks (sum_i t_i * z_i) + # but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab), so we need to divide predicted_logits by K to match? 
+ per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: per_sample_loss = per_sample_loss * loss_mask loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.AVG, group=group) return loss, grad @@ -213,71 +219,75 @@ def cross_entropy_forward_backward( ) +def distributed_log_softmax( + logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 +): + # logits = logits.float() + # local_max = logits.max(dim=dim, keepdim=True)[0] + # all_reduce(local_max, op=ReduceOp.MAX, group=group) + + # logits_shifted = logits - local_max + # exp_logits = torch.exp(logits_shifted) + # sum_exp = exp_logits.sum(dim=dim, keepdim=True) + # all_reduce(sum_exp, op=ReduceOp.SUM, group=group) + # return logits_shifted - sum_exp.log() # log_softmax + logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim) + + return logits_norm - sum_exp_logits.log() # log_softmax + + def _torch_reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - Much simpler and more reliable than custom implementation! + This is used for TP version where we split accross vocab dimantion. KL is additive over partitions of the vocab. + + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. """ + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") + # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # Compute log probabilities - let _fused_softmax handle scaling internally - # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) - # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) - # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 - # teacher_log_probs = torch.log(teacher_probs) - - # Scale target logits more carefully - scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) - - # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) - - # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) - # Use kl_div with: input=log(p), target=q, log_target=False - # This gives: Σ q * (log(q) - log(p)) = exactly what we want! 
- + # Compute log probabilities + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + batch_size = logits.shape[0] with torch.enable_grad(): - logits_ = logits.detach().requires_grad_(grad_output is not None) - - # Use log_softmax for consistency instead of _fused_softmax - scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) - student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - - # Convert to probabilities for kl_div - # student_probs_ = torch.exp(student_log_probs) + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits_, group=group) # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( teacher_log_probs, # input = log(p) student_log_probs, # target = log(q) - reduction="batchmean", + reduction="sum", log_target=True, ) else: - # Apply loss mask - this requires some reshaping - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).mean() + raise NotImplementedError("Loss mask not implemented with TP for reverse KL.") - if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + if group is not None: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= batch_size if grad_output is not None: loss.backward(torch.full_like(loss, grad_output)) @@ -288,6 +298,13 @@ def _torch_reverse_kl_forward_backward( return loss.detach_(), grad +REVERSE_KL_IMPLEMENTATIONS = { + ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward, + ReverseKLImpl.tp: _torch_reverse_kl_forward_backward, + ReverseKLImpl.stp: _torch_reverse_kl_forward_backward, +} + + def reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -297,6 +314,9 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + sequence_parallel_logits: bool = False, + group_size: int = None, + vocab_size: int = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -309,28 +329,33 @@ def reverse_kl_forward_backward( - Standard CE: KL(p||q) = mode-covering (spreads mass broadly) - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) - Args: - logits: Model predictions [batch_size, ..., vocab_size] - target: Target distribution or labels - loss_mask: Optional mask for loss computation - grad_output: Gradient output scale factor - group: Process group for tensor parallelism - logits_scale_factor: Temperature scaling factor (1/T) - target_format: Format of target (labels or logits) - Returns: loss: Reverse KL divergence loss grad: Gradients w.r.t. logits - - Example usage: - # Replace standard cross-entropy with reverse KL - # loss, grad = cross_entropy_forward_backward(logits, target, ...) 
- loss, grad = reverse_kl_forward_backward(logits, target, - loss_mask=None, - grad_output=1.0, - logits_scale_factor=1.0, - target_format=TargetFormat.labels) """ + + total_valid_tokens = logits.shape[0] + if logits.shape[-1] != vocab_size: + reverse_kl_impl = ReverseKLImpl.tp + assert loss_mask is None, "Loss mask is not implemented for vocab-TP in reverse KL yet" + elif sequence_parallel_logits: + # the case when vocab_parallel is False and sequence_parallel is True + # grad_output already reflects scaling 1/ number of ranks (group_size), see _forward_backward + reverse_kl_impl = ReverseKLImpl.stp + if loss_mask is not None: + local_valid_tokens = loss_mask.sum() + total_valid_tokens = local_valid_tokens.clone() + all_reduce(total_valid_tokens, op=ReduceOp.SUM, group=group) + else: + local_valid_tokens = logits.shape[0] + total_valid_tokens = local_valid_tokens * group_size + # in the loss function we compute grads w.r.t sum of losses, + # so we need to multiply back by the group size and divide by the number of valid tokens to get the correct scaling + # note, the function returns the sum of local losses, so we need to handle this properly for reporting + grad_output *= group_size / total_valid_tokens # multiply back by the group size + else: + reverse_kl_impl = ReverseKLImpl.no_tp + if target_format == TargetFormat.labels: Assert.eq(target.shape, logits.shape[:-1]) Assert.eq(target.dtype, torch.int64) @@ -339,7 +364,21 @@ def reverse_kl_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) + # TODO: implement fused? - return _torch_reverse_kl_forward_backward( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature + distillation_loss, distillation_grad = REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, ) + + if sequence_parallel_logits: + # distillation_loss is local sum, so we need to divide by the number of valid tokens to get the correct scaling + all_reduce(distillation_loss, op=ReduceOp.SUM, group=group) + distillation_loss /= total_valid_tokens # final global loss + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 180785af..d14cc5ed 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -387,7 +387,11 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), + sequence_parallel_logits=self._sequence_parallel_logits, + group_size=self._distributed_config.tensor_parallel, + vocab_size=self._vocab_dim.global_size, ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), From 28a47e4cbe4efa572050c59cadbdb4ccbf26f3da Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 19:00:49 +0000 Subject: [PATCH 02/10] wip --- fast_llm/functional/cross_entropy.py | 91 +++++++++++++--------------- 1 file changed, 41 insertions(+), 50 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0ecdddcc..35cf247d 
100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -222,15 +222,6 @@ def cross_entropy_forward_backward( def distributed_log_softmax( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 ): - # logits = logits.float() - # local_max = logits.max(dim=dim, keepdim=True)[0] - # all_reduce(local_max, op=ReduceOp.MAX, group=group) - - # logits_shifted = logits - local_max - # exp_logits = torch.exp(logits_shifted) - # sum_exp = exp_logits.sum(dim=dim, keepdim=True) - # all_reduce(sum_exp, op=ReduceOp.SUM, group=group) - # return logits_shifted - sum_exp.log() # log_softmax logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim) return logits_norm - sum_exp_logits.log() # log_softmax @@ -253,6 +244,10 @@ def _torch_reverse_kl_forward_backward( This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. + + Takes: + logits: [BxS, V] + target: [BxS, V] """ Assert.eq( teacher_softmax_temperature, @@ -260,8 +255,6 @@ def _torch_reverse_kl_forward_backward( msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", ) Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - # TODO: merge into single function _torch_reverse_kl_forward_backward - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: @@ -269,25 +262,37 @@ def _torch_reverse_kl_forward_backward( # Compute log probabilities teacher_log_probs = distributed_log_softmax(target.float(), group=group) - batch_size = logits.shape[0] + # batch_size = logits.shape[0] with torch.enable_grad(): logits_ = logits.float().detach().requires_grad_(grad_output is not None) student_log_probs = distributed_log_softmax(logits_, group=group) # Reverse KL: input=teacher_log_probs, target=student_probs - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="sum", - log_target=True, - ) + loss_terms = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + valid = loss_mask.to(loss_terms.dtype) + loss_terms = loss_terms * valid + valid_tokens = valid.sum() else: - raise NotImplementedError("Loss mask not implemented with TP for reverse KL.") + valid_tokens = logits.shape[0] * logits.shape[1] + loss = loss_terms.sum() # sums over batch and seq. len. 
if group is not None: all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= batch_size + if loss_mask is not None: + valid_tokens_all = torch.as_tensor(valid_tokens, device=loss.device, dtype=loss.dtype) + all_reduce(valid_tokens_all, op=ReduceOp.SUM, group=group) + valid_tokens = valid_tokens_all + else: + valid_tokens = torch.as_tensor(valid_tokens * group.size(), device=loss.device, dtype=loss.dtype) + else: + valid_tokens = torch.as_tensor(valid_tokens, device=loss.device, dtype=loss.dtype) + loss /= valid_tokens if grad_output is not None: loss.backward(torch.full_like(loss, grad_output)) @@ -329,41 +334,32 @@ def reverse_kl_forward_backward( - Standard CE: KL(p||q) = mode-covering (spreads mass broadly) - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) + Takes: + logits: [BxS, V], where V is local vocab size + target: [BxS, V] (logits format) + loss_mask: [BxS] or None + ... + Returns: loss: Reverse KL divergence loss grad: Gradients w.r.t. logits """ + # TODO: should we except the possibility of B x S x V shapes? - total_valid_tokens = logits.shape[0] if logits.shape[-1] != vocab_size: reverse_kl_impl = ReverseKLImpl.tp - assert loss_mask is None, "Loss mask is not implemented for vocab-TP in reverse KL yet" elif sequence_parallel_logits: - # the case when vocab_parallel is False and sequence_parallel is True - # grad_output already reflects scaling 1/ number of ranks (group_size), see _forward_backward - reverse_kl_impl = ReverseKLImpl.stp - if loss_mask is not None: - local_valid_tokens = loss_mask.sum() - total_valid_tokens = local_valid_tokens.clone() - all_reduce(total_valid_tokens, op=ReduceOp.SUM, group=group) - else: - local_valid_tokens = logits.shape[0] - total_valid_tokens = local_valid_tokens * group_size - # in the loss function we compute grads w.r.t sum of losses, - # so we need to multiply back by the group size and divide by the number of valid tokens to get the correct scaling - # note, the function returns the sum of local losses, so we need to handle this properly for reporting - grad_output *= group_size / total_valid_tokens # multiply back by the group size + # TODO: see hybrid dev branch where it is implemented + raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") else: reverse_kl_impl = ReverseKLImpl.no_tp - if target_format == TargetFormat.labels: - Assert.eq(target.shape, logits.shape[:-1]) - Assert.eq(target.dtype, torch.int64) - else: - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? 
distillation_loss, distillation_grad = REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( @@ -376,9 +372,4 @@ def reverse_kl_forward_backward( teacher_softmax_temperature=teacher_softmax_temperature, group=group, ) - - if sequence_parallel_logits: - # distillation_loss is local sum, so we need to divide by the number of valid tokens to get the correct scaling - all_reduce(distillation_loss, op=ReduceOp.SUM, group=group) - distillation_loss /= total_valid_tokens # final global loss return distillation_loss, distillation_grad From 13d0c7d2024b346c8ab3c8eada72ccd6ddf9255f Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 19:48:24 +0000 Subject: [PATCH 03/10] test --- fast_llm/functional/cross_entropy.py | 31 ++--- fast_llm/models/gpt/conversion/llama.py | 8 +- tests/test_rkl_loss.py | 145 ++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 25 deletions(-) create mode 100644 tests/test_rkl_loss.py diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 35cf247d..b01d566e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -242,12 +242,11 @@ def _torch_reverse_kl_forward_backward( Reverse KL using PyTorch's native kl_div function. This is used for TP version where we split accross vocab dimantion. KL is additive over partitions of the vocab. - This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. - In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. - Takes: - logits: [BxS, V] - target: [BxS, V] + logits: [BxS, V] or [B, S, V] + target: [BxS, V] or [B, S, V] (logits format) + loss_mask: [BxS] or [B, S] or None + ... """ Assert.eq( teacher_softmax_temperature, @@ -275,23 +274,16 @@ def _torch_reverse_kl_forward_backward( log_target=True, ).sum(dim=-1) if loss_mask is not None: + # loss mask is the same on all ranks for TP over vocab. valid = loss_mask.to(loss_terms.dtype) loss_terms = loss_terms * valid - valid_tokens = valid.sum() + valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype) else: - valid_tokens = logits.shape[0] * logits.shape[1] + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) loss = loss_terms.sum() # sums over batch and seq. len. 
if group is not None: all_reduce(loss, op=ReduceOp.SUM, group=group) - if loss_mask is not None: - valid_tokens_all = torch.as_tensor(valid_tokens, device=loss.device, dtype=loss.dtype) - all_reduce(valid_tokens_all, op=ReduceOp.SUM, group=group) - valid_tokens = valid_tokens_all - else: - valid_tokens = torch.as_tensor(valid_tokens * group.size(), device=loss.device, dtype=loss.dtype) - else: - valid_tokens = torch.as_tensor(valid_tokens, device=loss.device, dtype=loss.dtype) loss /= valid_tokens if grad_output is not None: @@ -306,7 +298,6 @@ def _torch_reverse_kl_forward_backward( REVERSE_KL_IMPLEMENTATIONS = { ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward, ReverseKLImpl.tp: _torch_reverse_kl_forward_backward, - ReverseKLImpl.stp: _torch_reverse_kl_forward_backward, } @@ -335,16 +326,15 @@ def reverse_kl_forward_backward( - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) Takes: - logits: [BxS, V], where V is local vocab size - target: [BxS, V] (logits format) - loss_mask: [BxS] or None + logits: [BxS, V] or [B, S, V], where V is local vocab size + target: [BxS, V] or [B, S, V] (logits format) + loss_mask: [BxS] or [B, S] or None ... Returns: loss: Reverse KL divergence loss grad: Gradients w.r.t. logits """ - # TODO: should we except the possibility of B x S x V shapes? if logits.shape[-1] != vocab_size: reverse_kl_impl = ReverseKLImpl.tp @@ -355,7 +345,6 @@ def reverse_kl_forward_backward( reverse_kl_impl = ReverseKLImpl.no_tp Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") - Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index b42f68b2..60d8183f 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -207,10 +207,10 @@ def import_config(cls, config: dict) -> dict: elif rope_type == "yarn": rotary_config.update( { - "attention_factor": config["attention_factor"], - "beta_fast": config["beta_fast"], - "beta_slow": config["beta_slow"], - "original_context_length": config["original_max_position_embeddings"], + "attention_factor": config["rope_scaling"]["attention_factor"], + "beta_fast": config["rope_scaling"]["beta_fast"], + "beta_slow": config["rope_scaling"]["beta_slow"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], } ) else: diff --git a/tests/test_rkl_loss.py b/tests/test_rkl_loss.py new file mode 100644 index 00000000..e24227e1 --- /dev/null +++ b/tests/test_rkl_loss.py @@ -0,0 +1,145 @@ +import os +import tempfile + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from fast_llm.functional.config import TargetFormat +from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + +def _mp_worker(rank: int, world_size: int, init_method: str, fn_name: str, fn_args: tuple): + fn = _WORKERS[fn_name] + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) + try: + fn(rank, dist.group.WORLD, *fn_args) + finally: + dist.destroy_process_group() + + +def _spawn_dist(world_size: int, fn, *fn_args): + """ + Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. 
+ """ + with tempfile.NamedTemporaryFile(delete=False) as tmp: + init_method = f"file://{tmp.name}" + + try: + mp.spawn( + _mp_worker, + args=(world_size, init_method, fn.__name__, fn_args), + nprocs=world_size, + join=True, + start_method="spawn", + ) + finally: + if os.path.exists(tmp.name): + os.remove(tmp.name) + + +def _assert_loss_and_grad(logits, loss, grad): + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 + assert grad is None or grad.shape == logits.shape + assert torch.isfinite(loss) + if grad is not None: + assert torch.isfinite(grad).all() + + +@pytest.mark.parametrize("use_mask", [False, True]) +def test_reverse_kl_no_tp(use_mask): + torch.manual_seed(0) + batch_size, seq_len, vocab_size = 2, 3, 5 + logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True) + target = torch.randn(batch_size, seq_len, vocab_size) + loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None + + loss, grad = reverse_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + group=None, + target_format=TargetFormat.logits, + sequence_parallel_logits=False, + group_size=None, + vocab_size=vocab_size, + ) + _assert_loss_and_grad(logits, loss, grad) + + # Manual reference: sum over vocab then average over valid tokens. + teacher_log_probs = torch.log_softmax(target, dim=-1) + student_log_probs = torch.log_softmax(logits, dim=-1) + per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + valid_tokens = loss_mask.sum() + else: + valid_tokens = logits.shape[0] * logits.shape[1] + reference = per_sample.sum() / valid_tokens + torch.testing.assert_close(loss, reference, atol=1e-6, rtol=1e-6) + + +def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): + torch.manual_seed(0) + world_size = dist.get_world_size(group) + + batch_size, seq_len, vocab_per_rank = 2, 3, 5 + full_vocab = vocab_per_rank * world_size + full_logits = torch.randn(batch_size, seq_len, full_vocab) + full_target = torch.randn(batch_size, seq_len, full_vocab) + full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None + + start = rank * vocab_per_rank + end = start + vocab_per_rank + logits = full_logits[:, :, start:end].clone().requires_grad_(True) + target = full_target[:, :, start:end].clone() + loss_mask = full_mask.clone() if full_mask is not None else None + + loss, grad = reverse_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=None, + group=group, + target_format=TargetFormat.logits, + sequence_parallel_logits=False, + group_size=world_size, + vocab_size=full_vocab, + ) + _assert_loss_and_grad(logits, loss, grad) + + if rank == 0: + ref_loss, _ = reverse_kl_forward_backward( + logits=full_logits.clone(), + target=full_target.clone(), + loss_mask=full_mask.clone() if full_mask is not None else None, + grad_output=None, + group=None, + target_format=TargetFormat.logits, + sequence_parallel_logits=False, + group_size=None, + vocab_size=full_vocab, + ) + else: + ref_loss = torch.zeros_like(loss) + dist.broadcast(ref_loss, src=0, group=group) + torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) + + +_WORKERS = { + "_vocab_tp_worker": _vocab_tp_worker, +} + + +@pytest.mark.parametrize("use_mask", [True, False]) +def test_reverse_kl_vocab_tp_two_ranks(use_mask): + _spawn_dist(2, _vocab_tp_worker, 
use_mask) + + +if __name__ == "__main__": + pytest.main([__file__]) From 07d754b3a58a3ee485fe7bed69f28b6c05869e6f Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 21:04:57 +0000 Subject: [PATCH 04/10] comment --- fast_llm/engine/training/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index b836516a..867cca98 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -361,8 +361,7 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - # Assert.eq(self.model.distributed.tensor_parallel, 1) - # Assert.eq(self.model.distributed.sequence_data_parallel, 1) + Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() for reference_model in self.reference_models.values(): From 9390f276aa951641e876d5189810cba40b177811 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 3 Dec 2025 00:05:15 +0000 Subject: [PATCH 05/10] tests + CE loss bug --- fast_llm/functional/cross_entropy.py | 12 ++-- ..._rkl_loss.py => test_distillation_loss.py} | 58 +++++++++++++++++-- 2 files changed, 59 insertions(+), 11 deletions(-) rename tests/{test_rkl_loss.py => test_distillation_loss.py} (68%) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index b01d566e..a03dd3b4 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -144,12 +144,12 @@ def _fused_cross_entropy_forward_backward( all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) else: predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - - # shouldn't the predicted_logits be scaled by the number of ranks so that the average loss is correct? i.e. - # i.e. on each rank we calculate log Z - sum_i t_i * z_i, z_i is logit. - # then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i) - # = log Z - 1/K sum_ranks (sum_i t_i * z_i) - # but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab), so we need to divide predicted_logits by K to match? + if group is not None: + # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, z_i is logit. 
+ # then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i) + # = log Z - 1/K sum_ranks (sum_i t_i * z_i) + # but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab) + predicted_logits = predicted_logits * group.size() per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: diff --git a/tests/test_rkl_loss.py b/tests/test_distillation_loss.py similarity index 68% rename from tests/test_rkl_loss.py rename to tests/test_distillation_loss.py index e24227e1..b263a0c9 100644 --- a/tests/test_rkl_loss.py +++ b/tests/test_distillation_loss.py @@ -6,8 +6,8 @@ import torch.distributed as dist import torch.multiprocessing as mp -from fast_llm.functional.config import TargetFormat -from fast_llm.functional.cross_entropy import reverse_kl_forward_backward +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat +from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward def _mp_worker(rank: int, world_size: int, init_method: str, fn_name: str, fn_args: tuple): @@ -131,9 +131,52 @@ def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) -_WORKERS = { - "_vocab_tp_worker": _vocab_tp_worker, -} +def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): + torch.manual_seed(0) + world_size = dist.get_world_size(group) + + batch_size, seq_len, vocab_per_rank = 2, 3, 5 + full_vocab = vocab_per_rank * world_size + full_logits = torch.randn(batch_size, seq_len, full_vocab) + full_target = torch.randn(batch_size, seq_len, full_vocab) + full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None + + start = rank * vocab_per_rank + end = start + vocab_per_rank + logits = full_logits[:, :, start:end].clone().requires_grad_(True) + target = full_target[:, :, start:end].clone() + loss_mask = full_mask.clone() if full_mask is not None else None + + loss, grad = cross_entropy_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=None, + group=group, + implementation=CrossEntropyImpl.fused, + target_format=TargetFormat.logits, + logits_scale_factor=1.0, + ) + _assert_loss_and_grad(logits, loss, grad) + + if rank == 0: + ref_loss, _ = cross_entropy_forward_backward( + logits=full_logits.clone(), + target=full_target.clone(), + loss_mask=full_mask.clone() if full_mask is not None else None, + grad_output=None, + group=None, + implementation=CrossEntropyImpl.fused, + target_format=TargetFormat.logits, + logits_scale_factor=1.0, + ) + else: + ref_loss = torch.zeros_like(loss) + dist.broadcast(ref_loss, src=0, group=group) + torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) + + +_WORKERS = {"_vocab_tp_worker": _vocab_tp_worker, "_ce_vocab_tp_worker": _ce_vocab_tp_worker} @pytest.mark.parametrize("use_mask", [True, False]) @@ -141,5 +184,10 @@ def test_reverse_kl_vocab_tp_two_ranks(use_mask): _spawn_dist(2, _vocab_tp_worker, use_mask) +@pytest.mark.parametrize("use_mask", [True, False]) +def test_cross_entropy_vocab_tp_two_ranks(use_mask): + _spawn_dist(2, _ce_vocab_tp_worker, use_mask) + + if __name__ == "__main__": pytest.main([__file__]) From a3e862c6867d90d898a5f6fdaf8c91f7587a41fb Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 3 Dec 2025 03:35:16 +0000 Subject: [PATCH 06/10] CE loss --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py 
b/fast_llm/functional/cross_entropy.py index a03dd3b4..79d167fb 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -144,7 +144,7 @@ def _fused_cross_entropy_forward_backward( all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) else: predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - if group is not None: + if group is not None and target_format != TargetFormat.labels: # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, z_i is logit. # then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i) # = log Z - 1/K sum_ranks (sum_i t_i * z_i) From 16dd5a5e9754c5ec2b8aabcc47e7983c206cb0b2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 3 Dec 2025 13:22:29 +0000 Subject: [PATCH 07/10] clean ups --- fast_llm/functional/config.py | 6 ------ fast_llm/functional/cross_entropy.py | 26 +++++++------------------ fast_llm/layers/language_model/head.py | 2 -- fast_llm/models/gpt/conversion/llama.py | 8 ++++---- 4 files changed, 11 insertions(+), 31 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 012a04dd..68419384 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -89,12 +89,6 @@ def _set_activation_fn_map() -> None: MAX_DROPLESS_BLOCK_SIZE_ROW = 128 -class ReverseKLImpl(str, enum.Enum): - tp = "tp" - stp = "stp" - no_tp = "no_tp" - - class CrossEntropyImpl(str, enum.Enum): auto = "auto" torch = "torch" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 79d167fb..0148b208 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -141,14 +141,14 @@ def _fused_cross_entropy_forward_backward( predicted_logits = logits_norm.gather(1, target) if group is not None: predicted_logits = target_mask * predicted_logits + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) else: predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) if group is not None and target_format != TargetFormat.labels: - # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, z_i is logit. - # then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i) - # = log Z - 1/K sum_ranks (sum_i t_i * z_i) - # but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab) + # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit. + # Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i) + # = log Z - 1/K sum_ranks (sum_i t_i * z_i), where is the global predicted_logits, so without multiplying it by K 1/K there does not cancel out. 
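+            # In other words, sum_ranks (sum_i t_i * z_i) over all K ranks is the full-vocab
+            # predicted_logits, so each local partial sum must be multiplied by K for the 1/K
+            # from the averaging to cancel. Toy example with K = 2: log Z = 3.0 and local partial
+            # sums 0.4 and 0.6 give a true loss of 3.0 - 1.0 = 2.0; without the factor K the rank
+            # losses 2.6 and 2.4 average to 2.5, with it they become 2.2 and 1.8 and average to 2.0.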
predicted_logits = predicted_logits * group.size() per_sample_loss = sum_exp_logits.log() - predicted_logits @@ -295,12 +295,6 @@ def _torch_reverse_kl_forward_backward( return loss.detach_(), grad -REVERSE_KL_IMPLEMENTATIONS = { - ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward, - ReverseKLImpl.tp: _torch_reverse_kl_forward_backward, -} - - def reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -311,8 +305,6 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - group_size: int = None, - vocab_size: int = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -336,13 +328,9 @@ def reverse_kl_forward_backward( grad: Gradients w.r.t. logits """ - if logits.shape[-1] != vocab_size: - reverse_kl_impl = ReverseKLImpl.tp - elif sequence_parallel_logits: + if sequence_parallel_logits: # TODO: see hybrid dev branch where it is implemented raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") - else: - reverse_kl_impl = ReverseKLImpl.no_tp Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) @@ -351,7 +339,7 @@ def reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? - distillation_loss, distillation_grad = REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( + distillation_loss, distillation_grad = _torch_cross_entropy_forward_backward( logits=logits, target=target, loss_mask=loss_mask, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index d14cc5ed..b1d0c2ac 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -388,8 +388,6 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), sequence_parallel_logits=self._sequence_parallel_logits, - group_size=self._distributed_config.tensor_parallel, - vocab_size=self._vocab_dim.global_size, ) elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 60d8183f..d8219419 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -198,10 +198,10 @@ def import_config(cls, config: dict) -> dict: elif rope_type == "llama3": rotary_config.update( { - "scale_factor": config["factor"], - "low_frequency_factor": config["low_freq_factor"], - "high_frequency_factor": config["high_freq_factor"], - "original_context_length": config["original_max_position_embeddings"], + "scale_factor": config["rope_scaling"]["factor"], + "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], + "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], } ) elif rope_type == "yarn": From d002b95409e384db2e4a7941daf18f9f1a5ee0ad Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 3 Dec 2025 13:29:42 +0000 Subject: [PATCH 08/10] clean up --- fast_llm/functional/cross_entropy.py | 2 +- tests/test_distillation_loss.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git 
a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0148b208..42b0c214 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -339,7 +339,7 @@ def reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? - distillation_loss, distillation_grad = _torch_cross_entropy_forward_backward( + distillation_loss, distillation_grad = _torch_reverse_kl_forward_backward( logits=logits, target=target, loss_mask=loss_mask, diff --git a/tests/test_distillation_loss.py b/tests/test_distillation_loss.py index b263a0c9..428c4ae5 100644 --- a/tests/test_distillation_loss.py +++ b/tests/test_distillation_loss.py @@ -64,8 +64,6 @@ def test_reverse_kl_no_tp(use_mask): group=None, target_format=TargetFormat.logits, sequence_parallel_logits=False, - group_size=None, - vocab_size=vocab_size, ) _assert_loss_and_grad(logits, loss, grad) @@ -108,8 +106,6 @@ def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): group=group, target_format=TargetFormat.logits, sequence_parallel_logits=False, - group_size=world_size, - vocab_size=full_vocab, ) _assert_loss_and_grad(logits, loss, grad) @@ -122,8 +118,6 @@ def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): group=None, target_format=TargetFormat.logits, sequence_parallel_logits=False, - group_size=None, - vocab_size=full_vocab, ) else: ref_loss = torch.zeros_like(loss) From 264aaf71ba99f309539539c6a9b65e89c9da416b Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 3 Dec 2025 13:38:48 +0000 Subject: [PATCH 09/10] test --- tests/test_distillation_loss.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_distillation_loss.py b/tests/test_distillation_loss.py index 428c4ae5..4d6b9934 100644 --- a/tests/test_distillation_loss.py +++ b/tests/test_distillation_loss.py @@ -10,8 +10,8 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward -def _mp_worker(rank: int, world_size: int, init_method: str, fn_name: str, fn_args: tuple): - fn = _WORKERS[fn_name] +def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): + fn = combined_worker dist.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) try: fn(rank, dist.group.WORLD, *fn_args) @@ -29,7 +29,7 @@ def _spawn_dist(world_size: int, fn, *fn_args): try: mp.spawn( _mp_worker, - args=(world_size, init_method, fn.__name__, fn_args), + args=(world_size, init_method, fn_args), nprocs=world_size, join=True, start_method="spawn", @@ -170,17 +170,14 @@ def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) -_WORKERS = {"_vocab_tp_worker": _vocab_tp_worker, "_ce_vocab_tp_worker": _ce_vocab_tp_worker} +def combined_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): + _vocab_tp_worker(rank, group, use_mask) + _ce_vocab_tp_worker(rank, group, use_mask) @pytest.mark.parametrize("use_mask", [True, False]) -def test_reverse_kl_vocab_tp_two_ranks(use_mask): - _spawn_dist(2, _vocab_tp_worker, use_mask) - - -@pytest.mark.parametrize("use_mask", [True, False]) -def test_cross_entropy_vocab_tp_two_ranks(use_mask): - _spawn_dist(2, _ce_vocab_tp_worker, use_mask) +def test_reverse_combined(use_mask): + _spawn_dist(2, combined_worker, use_mask) if __name__ == "__main__": From 3994a447ff8c3bbf80a7f9904ef8193f8246a18b Mon Sep 17 00:00:00 2001 
From: oleksost Date: Wed, 3 Dec 2025 13:39:36 +0000 Subject: [PATCH 10/10] mark slow --- tests/test_distillation_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_distillation_loss.py b/tests/test_distillation_loss.py index 4d6b9934..f4238455 100644 --- a/tests/test_distillation_loss.py +++ b/tests/test_distillation_loss.py @@ -175,6 +175,7 @@ def combined_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): _ce_vocab_tp_worker(rank, group, use_mask) +@pytest.mark.slow @pytest.mark.parametrize("use_mask", [True, False]) def test_reverse_combined(use_mask): _spawn_dist(2, combined_worker, use_mask)
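Note (not part of the patch series above): the vocab-parallel path relies on reverse KL being additive over partitions of the vocabulary once the log-probabilities are normalized globally. A minimal single-process sketch of that property, using plain torch and illustrative shapes (4 tokens, vocab of 12 split into 3 shards); the shard layout and names here are purely for illustration:

import torch

torch.manual_seed(0)
tokens, vocab, shards = 4, 12, 3

student_logits = torch.randn(tokens, vocab)
teacher_logits = torch.randn(tokens, vocab)

# Globally normalized log-probabilities. In the distributed implementation each rank only
# holds a vocab shard and the normalizer comes from an all-reduce (distributed_log_softmax);
# here a full log_softmax over the whole vocab plays that role.
student_log_probs = torch.log_softmax(student_logits, dim=-1)
teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)

# Full reverse KL(q||p) per token: sum_i q_i * (log q_i - log p_i).
full_kl = torch.nn.functional.kl_div(
    teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)

# Per-shard contributions, as each tensor-parallel rank would compute on its slice of the vocab.
shard_kl = sum(
    torch.nn.functional.kl_div(p_shard, q_shard, reduction="none", log_target=True).sum(dim=-1)
    for p_shard, q_shard in zip(
        teacher_log_probs.chunk(shards, dim=-1), student_log_probs.chunk(shards, dim=-1)
    )
)

# Summing the shard terms (the all_reduce(..., op=ReduceOp.SUM) in the patch) recovers the full KL.
torch.testing.assert_close(full_kl, shard_kl)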