-
Notifications
You must be signed in to change notification settings - Fork 39
TP support for reverse KL loss #400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
oleksost
wants to merge
5
commits into
main
Choose a base branch
from
rev_kl_tp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+297
β76
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| import torch | ||
|
|
||
| from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce | ||
| from fast_llm.functional.config import CrossEntropyImpl, TargetFormat | ||
| from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat | ||
| from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward | ||
| from fast_llm.utils import Assert | ||
|
|
||
|
|
@@ -144,14 +144,20 @@ def _fused_cross_entropy_forward_backward( | |
| all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) | ||
| else: | ||
| predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) | ||
| if group is not None: | ||
| # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, z_i is logit. | ||
| # then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i) | ||
| # = log Z - 1/K sum_ranks (sum_i t_i * z_i) | ||
| # but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab) | ||
| predicted_logits = predicted_logits * group.size() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks wrong, see previous comment. The previous version was tested and confirmed to work. |
||
|
|
||
| per_sample_loss = sum_exp_logits.log() - predicted_logits | ||
| if loss_mask is not None: | ||
| per_sample_loss = per_sample_loss * loss_mask | ||
|
|
||
| loss = per_sample_loss.mean() | ||
| if target_format != TargetFormat.labels and group is not None: | ||
| all_reduce(loss, op=ReduceOp.MEAN, group=group) | ||
| all_reduce(loss, op=ReduceOp.AVG, group=group) | ||
|
|
||
| return loss, grad | ||
|
|
||
|
|
@@ -213,71 +219,72 @@ def cross_entropy_forward_backward( | |
| ) | ||
|
|
||
|
|
||
| def distributed_log_softmax( | ||
| logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 | ||
| ): | ||
| logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim) | ||
|
|
||
| return logits_norm - sum_exp_logits.log() # log_softmax | ||
|
|
||
|
|
||
| def _torch_reverse_kl_forward_backward( | ||
| logits: torch.Tensor, | ||
| target: torch.Tensor, | ||
| loss_mask: torch.Tensor | None, | ||
| grad_output: float | None, | ||
| logits_scale_factor: float, | ||
| target_format: TargetFormat, | ||
| group: ProcessGroup | None = None, | ||
| logits_scale_factor: float = 1.0, | ||
| teacher_softmax_temperature: float = 1.0, | ||
| **kwargs, | ||
| ) -> tuple[torch.Tensor, torch.Tensor | None]: | ||
| """ | ||
| Reverse KL using PyTorch's native kl_div function. | ||
| Much simpler and more reliable than custom implementation! | ||
| This is used for TP version where we split accross vocab dimantion. KL is additive over partitions of the vocab. | ||
|
|
||
| Takes: | ||
| logits: [BxS, V] or [B, S, V] | ||
| target: [BxS, V] or [B, S, V] (logits format) | ||
| loss_mask: [BxS] or [B, S] or None | ||
| ... | ||
| """ | ||
| Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") | ||
| Assert.eq( | ||
| teacher_softmax_temperature, | ||
| 1, | ||
| msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", | ||
| ) | ||
| Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") | ||
| Assert.eq(target.shape, logits.shape) | ||
| assert target.dtype.is_floating_point, target.dtype | ||
| if loss_mask is not None: | ||
| Assert.eq(loss_mask.shape, logits.shape[:-1]) | ||
|
|
||
| # Compute log probabilities - let _fused_softmax handle scaling internally | ||
| # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) | ||
| # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) | ||
| # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 | ||
| # teacher_log_probs = torch.log(teacher_probs) | ||
|
|
||
| # Scale target logits more carefully | ||
| scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) | ||
|
|
||
| # Clamp to prevent extreme values before log_softmax | ||
| scaled_target = torch.clamp(scaled_target, min=-50, max=50) | ||
| teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) | ||
|
|
||
| # For reverse KL: KL(q||p) = Ξ£ q * log(q/p) = Ξ£ q * (log(q) - log(p)) | ||
| # Use kl_div with: input=log(p), target=q, log_target=False | ||
| # This gives: Ξ£ q * (log(q) - log(p)) = exactly what we want! | ||
|
|
||
| # Compute log probabilities | ||
| teacher_log_probs = distributed_log_softmax(target.float(), group=group) | ||
| # batch_size = logits.shape[0] | ||
| with torch.enable_grad(): | ||
| logits_ = logits.detach().requires_grad_(grad_output is not None) | ||
|
|
||
| # Use log_softmax for consistency instead of _fused_softmax | ||
| scaled_logits = logits_ * logits_scale_factor | ||
| scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) | ||
| student_log_probs = torch.log_softmax(scaled_logits, dim=-1) | ||
|
|
||
| # Convert to probabilities for kl_div | ||
| # student_probs_ = torch.exp(student_log_probs) | ||
| logits_ = logits.float().detach().requires_grad_(grad_output is not None) | ||
| student_log_probs = distributed_log_softmax(logits_, group=group) | ||
|
|
||
| # Reverse KL: input=teacher_log_probs, target=student_probs | ||
| if loss_mask is None: | ||
| loss = torch.nn.functional.kl_div( | ||
| teacher_log_probs, # input = log(p) | ||
| student_log_probs, # target = log(q) | ||
| reduction="batchmean", | ||
| log_target=True, | ||
| ) | ||
| loss_terms = torch.nn.functional.kl_div( | ||
| teacher_log_probs, # input = log(p) | ||
| student_log_probs, # target = log(q) | ||
| reduction="none", | ||
| log_target=True, | ||
| ).sum(dim=-1) | ||
| if loss_mask is not None: | ||
| # loss mask is the same on all ranks for TP over vocab. | ||
| valid = loss_mask.to(loss_terms.dtype) | ||
| loss_terms = loss_terms * valid | ||
| valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype) | ||
| else: | ||
| # Apply loss mask - this requires some reshaping | ||
| loss_per_sample = torch.nn.functional.kl_div( | ||
| teacher_log_probs, student_log_probs, reduction="none", log_target=True | ||
| ).sum(dim=-1) | ||
| loss = (loss_per_sample * loss_mask).mean() | ||
| valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) | ||
| loss = loss_terms.sum() # sums over batch and seq. len. | ||
|
|
||
| if group is not None and target_format != TargetFormat.labels: | ||
| all_reduce(loss, op=ReduceOp.MEAN, group=group) | ||
| if group is not None: | ||
| all_reduce(loss, op=ReduceOp.SUM, group=group) | ||
| loss /= valid_tokens | ||
|
|
||
| if grad_output is not None: | ||
| loss.backward(torch.full_like(loss, grad_output)) | ||
|
|
@@ -288,6 +295,12 @@ def _torch_reverse_kl_forward_backward( | |
| return loss.detach_(), grad | ||
|
|
||
|
|
||
| REVERSE_KL_IMPLEMENTATIONS = { | ||
| ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward, | ||
| ReverseKLImpl.tp: _torch_reverse_kl_forward_backward, | ||
| } | ||
|
|
||
|
|
||
| def reverse_kl_forward_backward( | ||
| logits: torch.Tensor, | ||
| target: torch.Tensor, | ||
|
|
@@ -297,6 +310,9 @@ def reverse_kl_forward_backward( | |
| logits_scale_factor: float = 1.0, | ||
| teacher_softmax_temperature: float = 1.0, | ||
| target_format: TargetFormat = TargetFormat.labels, | ||
| sequence_parallel_logits: bool = False, | ||
| group_size: int = None, | ||
| vocab_size: int = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor | None]: | ||
| """ | ||
| Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). | ||
|
|
@@ -309,37 +325,40 @@ def reverse_kl_forward_backward( | |
| - Standard CE: KL(p||q) = mode-covering (spreads mass broadly) | ||
| - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) | ||
|
|
||
| Args: | ||
| logits: Model predictions [batch_size, ..., vocab_size] | ||
| target: Target distribution or labels | ||
| loss_mask: Optional mask for loss computation | ||
| grad_output: Gradient output scale factor | ||
| group: Process group for tensor parallelism | ||
| logits_scale_factor: Temperature scaling factor (1/T) | ||
| target_format: Format of target (labels or logits) | ||
| Takes: | ||
| logits: [BxS, V] or [B, S, V], where V is local vocab size | ||
| target: [BxS, V] or [B, S, V] (logits format) | ||
| loss_mask: [BxS] or [B, S] or None | ||
| ... | ||
|
|
||
| Returns: | ||
| loss: Reverse KL divergence loss | ||
| grad: Gradients w.r.t. logits | ||
|
|
||
| Example usage: | ||
| # Replace standard cross-entropy with reverse KL | ||
| # loss, grad = cross_entropy_forward_backward(logits, target, ...) | ||
| loss, grad = reverse_kl_forward_backward(logits, target, | ||
| loss_mask=None, | ||
| grad_output=1.0, | ||
| logits_scale_factor=1.0, | ||
| target_format=TargetFormat.labels) | ||
| """ | ||
| if target_format == TargetFormat.labels: | ||
| Assert.eq(target.shape, logits.shape[:-1]) | ||
| Assert.eq(target.dtype, torch.int64) | ||
|
|
||
| if logits.shape[-1] != vocab_size: | ||
| reverse_kl_impl = ReverseKLImpl.tp | ||
| elif sequence_parallel_logits: | ||
| # TODO: see hybrid dev branch where it is implemented | ||
| raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") | ||
| else: | ||
| Assert.eq(target.shape, logits.shape) | ||
| assert target.dtype.is_floating_point, target.dtype | ||
| if loss_mask is not None: | ||
| Assert.eq(loss_mask.shape, logits.shape[:-1]) | ||
| reverse_kl_impl = ReverseKLImpl.no_tp | ||
|
|
||
| Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") | ||
| Assert.eq(target.shape, logits.shape) | ||
| assert target.dtype.is_floating_point, target.dtype | ||
| if loss_mask is not None: | ||
| Assert.eq(loss_mask.shape, logits.shape[:-1]) | ||
|
|
||
| # TODO: implement fused? | ||
| return _torch_reverse_kl_forward_backward( | ||
| logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature | ||
| distillation_loss, distillation_grad = REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( | ||
| logits=logits, | ||
| target=target, | ||
| loss_mask=loss_mask, | ||
| grad_output=grad_output, | ||
| logits_scale_factor=logits_scale_factor, | ||
| target_format=target_format, | ||
| teacher_softmax_temperature=teacher_softmax_temperature, | ||
| group=group, | ||
| ) | ||
| return distillation_loss, distillation_grad | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -207,10 +207,10 @@ def import_config(cls, config: dict) -> dict: | |
| elif rope_type == "yarn": | ||
| rotary_config.update( | ||
| { | ||
| "attention_factor": config["attention_factor"], | ||
| "beta_fast": config["beta_fast"], | ||
| "beta_slow": config["beta_slow"], | ||
| "original_context_length": config["original_max_position_embeddings"], | ||
| "attention_factor": config["rope_scaling"]["attention_factor"], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be the same for llama 3 above? |
||
| "beta_fast": config["rope_scaling"]["beta_fast"], | ||
| "beta_slow": config["rope_scaling"]["beta_slow"], | ||
| "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], | ||
| } | ||
| ) | ||
| else: | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not needed, there is only one implementation.