diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py
index 531bc206..867cca98 100644
--- a/fast_llm/engine/training/config.py
+++ b/fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
             # TODO: Add support.
             Assert.eq(self.model.distributed.pipeline_parallel, 1)
             # TODO: Check if these work.
-            Assert.eq(self.model.distributed.tensor_parallel, 1)
             Assert.eq(self.model.distributed.sequence_data_parallel, 1)
         if self.run.experiment_dir is None:
             assert not self.training.checkpoint.enabled()
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index d56dce98..42b0c214 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -141,9 +141,15 @@ def _fused_cross_entropy_forward_backward(
         predicted_logits = logits_norm.gather(1, target)
         if group is not None:
             predicted_logits = target_mask * predicted_logits
+            all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
     else:
         predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)
+        if group is not None and target_format != TargetFormat.labels:
+            # This is needed because each rank computes log Z - sum_i t_i * z_i, where z_i is the logit.
+            # The per-rank losses are then averaged across the group (ReduceOp.AVG below): 1/K sum_ranks (log Z - sum_i t_i * z_i)
+            # = log Z - 1/K sum_ranks (sum_i t_i * z_i), where the sum over ranks is the global predicted_logits, so without first multiplying by K the 1/K would not cancel out.
+            predicted_logits = predicted_logits * group.size()

     per_sample_loss = sum_exp_logits.log() - predicted_logits
     if loss_mask is not None:
@@ -151,7 +157,7 @@
     loss = per_sample_loss.mean()
     if target_format != TargetFormat.labels and group is not None:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)

     return loss, grad
@@ -213,71 +219,72 @@ def cross_entropy_forward_backward(
     )


+def distributed_log_softmax(
+    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
+):
+    logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim)
+
+    return logits_norm - sum_exp_logits.log()  # log_softmax
+
+
 def _torch_reverse_kl_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
     loss_mask: torch.Tensor | None,
     grad_output: float | None,
-    logits_scale_factor: float,
     target_format: TargetFormat,
     group: ProcessGroup | None = None,
+    logits_scale_factor: float = 1.0,
     teacher_softmax_temperature: float = 1.0,
+    **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Reverse KL using PyTorch's native kl_div function.
-    Much simpler and more reliable than custom implementation!
+    This is used for the TP version where we split across the vocab dimension. KL is additive over partitions of the vocab.
+
+    Takes:
+        logits: [BxS, V] or [B, S, V]
+        target: [BxS, V] or [B, S, V] (logits format)
+        loss_mask: [BxS] or [B, S] or None
+        ...
""" - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # Compute log probabilities - let _fused_softmax handle scaling internally - # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) - # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) - # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 - # teacher_log_probs = torch.log(teacher_probs) - - # Scale target logits more carefully - scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) - - # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) - - # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) - # Use kl_div with: input=log(p), target=q, log_target=False - # This gives: Σ q * (log(q) - log(p)) = exactly what we want! - + # Compute log probabilities + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + # batch_size = logits.shape[0] with torch.enable_grad(): - logits_ = logits.detach().requires_grad_(grad_output is not None) - - # Use log_softmax for consistency instead of _fused_softmax - scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) - student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - - # Convert to probabilities for kl_div - # student_probs_ = torch.exp(student_log_probs) + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits_, group=group) # Reverse KL: input=teacher_log_probs, target=student_probs - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="batchmean", - log_target=True, - ) + loss_terms = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + # loss mask is the same on all ranks for TP over vocab. + valid = loss_mask.to(loss_terms.dtype) + loss_terms = loss_terms * valid + valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype) else: - # Apply loss mask - this requires some reshaping - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).mean() + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) + loss = loss_terms.sum() # sums over batch and seq. len. 
@@ -297,6 +304,7 @@ def reverse_kl_forward_backward(
     logits_scale_factor: float = 1.0,
     teacher_softmax_temperature: float = 1.0,
     target_format: TargetFormat = TargetFormat.labels,
+    sequence_parallel_logits: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher).
@@ -309,37 +317,36 @@
     - Standard CE: KL(p||q) = mode-covering (spreads mass broadly)
     - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes)

-    Args:
-        logits: Model predictions [batch_size, ..., vocab_size]
-        target: Target distribution or labels
-        loss_mask: Optional mask for loss computation
-        grad_output: Gradient output scale factor
-        group: Process group for tensor parallelism
-        logits_scale_factor: Temperature scaling factor (1/T)
-        target_format: Format of target (labels or logits)
+    Takes:
+        logits: [BxS, V] or [B, S, V], where V is the local vocab size
+        target: [BxS, V] or [B, S, V] (logits format)
+        loss_mask: [BxS] or [B, S] or None
+        ...

     Returns:
         loss: Reverse KL divergence loss
         grad: Gradients w.r.t. logits
-
-    Example usage:
-        # Replace standard cross-entropy with reverse KL
-        # loss, grad = cross_entropy_forward_backward(logits, target, ...)
-        loss, grad = reverse_kl_forward_backward(logits, target,
-                                                 loss_mask=None,
-                                                 grad_output=1.0,
-                                                 logits_scale_factor=1.0,
-                                                 target_format=TargetFormat.labels)
     """
-    if target_format == TargetFormat.labels:
-        Assert.eq(target.shape, logits.shape[:-1])
-        Assert.eq(target.dtype, torch.int64)
-    else:
-        Assert.eq(target.shape, logits.shape)
-        assert target.dtype.is_floating_point, target.dtype
-    if loss_mask is not None:
-        Assert.eq(loss_mask.shape, logits.shape[:-1])
+
+    if sequence_parallel_logits:
+        # TODO: see hybrid dev branch where it is implemented
+        raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true")
+
+    Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
+    Assert.eq(target.shape, logits.shape)
+    assert target.dtype.is_floating_point, target.dtype
+    if loss_mask is not None:
+        Assert.eq(loss_mask.shape, logits.shape[:-1])
+    # TODO: implement fused?

-    return _torch_reverse_kl_forward_backward(
-        logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature
+    distillation_loss, distillation_grad = _torch_reverse_kl_forward_backward(
+        logits=logits,
+        target=target,
+        loss_mask=loss_mask,
+        grad_output=grad_output,
+        logits_scale_factor=logits_scale_factor,
+        target_format=target_format,
+        teacher_softmax_temperature=teacher_softmax_temperature,
+        group=group,
     )
+    return distillation_loss, distillation_grad
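The earlier hunks in this file also change the vocab-parallel reduction of the fused cross-entropy: the local predicted_logits are multiplied by group.size() before the per-rank losses are AVG-reduced. A standalone numeric sketch of that bookkeeping (not part of the patch; the two-way vocab split and shapes are illustrative, with plain slicing standing in for the shards and the all-reduce):

import torch

torch.manual_seed(0)
K, tokens, vocab_per_rank = 2, 3, 4
logits = torch.randn(tokens, K * vocab_per_rank)
targets = torch.softmax(torch.randn(tokens, K * vocab_per_rank), dim=-1)  # soft targets

log_z = torch.logsumexp(logits, dim=-1)             # global normalizer (already all-reduced in the real code)
reference = log_z - (targets * logits).sum(dim=-1)  # log Z - sum_i t_i * z_i over the full vocab

# Per-rank contribution: each rank only sees its own vocab shard of sum_i t_i * z_i.
local = [
    (targets[:, r * vocab_per_rank : (r + 1) * vocab_per_rank]
     * logits[:, r * vocab_per_rank : (r + 1) * vocab_per_rank]).sum(dim=-1)
    for r in range(K)
]

# Averaging (log Z - local_sum) across ranks leaves a stray 1/K on the local sums...
naive_avg = torch.stack([log_z - l for l in local]).mean(dim=0)
# ...unless each local sum is first scaled by K, as the patch does with group.size().
fixed_avg = torch.stack([log_z - K * l for l in local]).mean(dim=0)

torch.testing.assert_close(fixed_avg, reference)
assert not torch.allclose(naive_avg, reference)

The point is that log Z is already globally reduced and so is counted in full on every rank, while each rank only holds 1/K of the sum_i t_i * z_i term; scaling the local term by K makes the cross-rank average reproduce the full-vocab loss.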
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 180785af..b1d0c2ac 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -387,7 +387,9 @@ def _logits_cross_entropy_forward_backward(
                 target_format=(
                     TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits
                 ),
+                sequence_parallel_logits=self._sequence_parallel_logits,
             )
+
         elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy:
             distillation_loss, distillation_grad = cross_entropy_forward_backward(
                 logits.flatten(0, -2),
diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py
index b42f68b2..d8219419 100644
--- a/fast_llm/models/gpt/conversion/llama.py
+++ b/fast_llm/models/gpt/conversion/llama.py
@@ -198,19 +198,19 @@ def import_config(cls, config: dict) -> dict:
         elif rope_type == "llama3":
             rotary_config.update(
                 {
-                    "scale_factor": config["factor"],
-                    "low_frequency_factor": config["low_freq_factor"],
-                    "high_frequency_factor": config["high_freq_factor"],
-                    "original_context_length": config["original_max_position_embeddings"],
+                    "scale_factor": config["rope_scaling"]["factor"],
+                    "low_frequency_factor": config["rope_scaling"]["low_freq_factor"],
+                    "high_frequency_factor": config["rope_scaling"]["high_freq_factor"],
+                    "original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
                 }
             )
         elif rope_type == "yarn":
             rotary_config.update(
                 {
-                    "attention_factor": config["attention_factor"],
-                    "beta_fast": config["beta_fast"],
-                    "beta_slow": config["beta_slow"],
-                    "original_context_length": config["original_max_position_embeddings"],
+                    "attention_factor": config["rope_scaling"]["attention_factor"],
+                    "beta_fast": config["rope_scaling"]["beta_fast"],
+                    "beta_slow": config["rope_scaling"]["beta_slow"],
+                    "original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
                 }
             )
         else:
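The llama.py change reads the RoPE scaling parameters from the nested "rope_scaling" section of the Hugging Face config rather than from its top level. An illustrative slice of such a config (values are examples only, in the shape used by Llama 3.1-style checkpoints):

# Illustrative slice of a Hugging Face Llama config dict (example values).
config = {
    "max_position_embeddings": 131072,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}

# The scaling keys live only under "rope_scaling"; looking them up on the top-level
# dict, as the old code did, would raise KeyError on a config shaped like this.
assert config["rope_scaling"]["factor"] == 8.0
assert "factor" not in config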
+ """ + with tempfile.NamedTemporaryFile(delete=False) as tmp: + init_method = f"file://{tmp.name}" + + try: + mp.spawn( + _mp_worker, + args=(world_size, init_method, fn_args), + nprocs=world_size, + join=True, + start_method="spawn", + ) + finally: + if os.path.exists(tmp.name): + os.remove(tmp.name) + + +def _assert_loss_and_grad(logits, loss, grad): + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 + assert grad is None or grad.shape == logits.shape + assert torch.isfinite(loss) + if grad is not None: + assert torch.isfinite(grad).all() + + +@pytest.mark.parametrize("use_mask", [False, True]) +def test_reverse_kl_no_tp(use_mask): + torch.manual_seed(0) + batch_size, seq_len, vocab_size = 2, 3, 5 + logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True) + target = torch.randn(batch_size, seq_len, vocab_size) + loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None + + loss, grad = reverse_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + group=None, + target_format=TargetFormat.logits, + sequence_parallel_logits=False, + ) + _assert_loss_and_grad(logits, loss, grad) + + # Manual reference: sum over vocab then average over valid tokens. + teacher_log_probs = torch.log_softmax(target, dim=-1) + student_log_probs = torch.log_softmax(logits, dim=-1) + per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + valid_tokens = loss_mask.sum() + else: + valid_tokens = logits.shape[0] * logits.shape[1] + reference = per_sample.sum() / valid_tokens + torch.testing.assert_close(loss, reference, atol=1e-6, rtol=1e-6) + + +def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): + torch.manual_seed(0) + world_size = dist.get_world_size(group) + + batch_size, seq_len, vocab_per_rank = 2, 3, 5 + full_vocab = vocab_per_rank * world_size + full_logits = torch.randn(batch_size, seq_len, full_vocab) + full_target = torch.randn(batch_size, seq_len, full_vocab) + full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None + + start = rank * vocab_per_rank + end = start + vocab_per_rank + logits = full_logits[:, :, start:end].clone().requires_grad_(True) + target = full_target[:, :, start:end].clone() + loss_mask = full_mask.clone() if full_mask is not None else None + + loss, grad = reverse_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=None, + group=group, + target_format=TargetFormat.logits, + sequence_parallel_logits=False, + ) + _assert_loss_and_grad(logits, loss, grad) + + if rank == 0: + ref_loss, _ = reverse_kl_forward_backward( + logits=full_logits.clone(), + target=full_target.clone(), + loss_mask=full_mask.clone() if full_mask is not None else None, + grad_output=None, + group=None, + target_format=TargetFormat.logits, + sequence_parallel_logits=False, + ) + else: + ref_loss = torch.zeros_like(loss) + dist.broadcast(ref_loss, src=0, group=group) + torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) + + +def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): + torch.manual_seed(0) + world_size = dist.get_world_size(group) + + batch_size, seq_len, vocab_per_rank = 2, 3, 5 + full_vocab = vocab_per_rank * world_size + full_logits = torch.randn(batch_size, seq_len, full_vocab) + full_target = torch.randn(batch_size, seq_len, full_vocab) 
+    full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None
+
+    start = rank * vocab_per_rank
+    end = start + vocab_per_rank
+    logits = full_logits[:, :, start:end].clone().requires_grad_(True)
+    target = full_target[:, :, start:end].clone()
+    loss_mask = full_mask.clone() if full_mask is not None else None
+
+    loss, grad = cross_entropy_forward_backward(
+        logits=logits,
+        target=target,
+        loss_mask=loss_mask,
+        grad_output=None,
+        group=group,
+        implementation=CrossEntropyImpl.fused,
+        target_format=TargetFormat.logits,
+        logits_scale_factor=1.0,
+    )
+    _assert_loss_and_grad(logits, loss, grad)
+
+    if rank == 0:
+        ref_loss, _ = cross_entropy_forward_backward(
+            logits=full_logits.clone(),
+            target=full_target.clone(),
+            loss_mask=full_mask.clone() if full_mask is not None else None,
+            grad_output=None,
+            group=None,
+            implementation=CrossEntropyImpl.fused,
+            target_format=TargetFormat.logits,
+            logits_scale_factor=1.0,
+        )
+    else:
+        ref_loss = torch.zeros_like(loss)
+    dist.broadcast(ref_loss, src=0, group=group)
+    torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6)
+
+
+def combined_worker(rank: int, group: dist.ProcessGroup, use_mask: bool):
+    _vocab_tp_worker(rank, group, use_mask)
+    _ce_vocab_tp_worker(rank, group, use_mask)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("use_mask", [True, False])
+def test_reverse_combined(use_mask):
+    _spawn_dist(2, combined_worker, use_mask)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
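The tests above check finiteness and single-rank versus vocab-parallel consistency, but not the analytic form of the gradient that reverse_kl_forward_backward returns. A self-contained sketch in plain torch (independent of fast_llm) comparing autograd against the closed form q_j * (log q_j - log p_j - KL_token) / num_tokens, which can be handy when eyeballing grad values; the uniform 1/num_tokens factor here stands in for the valid-token normalization used with a loss mask:

import torch

torch.manual_seed(0)
tokens, vocab = 4, 6
student_logits = torch.randn(tokens, vocab, requires_grad=True)
teacher_logits = torch.randn(tokens, vocab)

log_q = torch.log_softmax(student_logits, dim=-1)
log_p = torch.log_softmax(teacher_logits, dim=-1)
per_token = (log_q.exp() * (log_q - log_p)).sum(dim=-1)  # reverse KL per token
loss = per_token.mean()
loss.backward()

# Closed form of the gradient w.r.t. the student logits:
#   d loss / d z_j = q_j * (log q_j - log p_j - KL_token) / num_tokens
q = log_q.exp().detach()
analytic = q * (log_q.detach() - log_p - per_token.detach().unsqueeze(-1)) / tokens
torch.testing.assert_close(student_logits.grad, analytic)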