TP support for reverse KL loss #400
base: main
Conversation
fast_llm/functional/cross_entropy.py
Outdated
else:
    predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)

# shouldn't the predicted_logits be scaled by the number of ranks so that the average loss is correct? i.e.
There is no averaging. We calculate log Z - sum_ranks (sum_i t_i * z_i_rank * mask) which is the same as log Z - sum_i t_i * z_i.
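A minimal sketch of that reduction, assuming a vocab-sharded layout and a standard `torch.distributed` SUM all-reduce (the function and tensor names below are illustrative, not the actual Fast-LLM code):

```python
import torch
import torch.distributed as dist


def predicted_logits_vocab_tp(logits_shard, target_shard, group):
    """Per-rank contribution to sum_i t_i * z_i when the vocab dim is sharded.

    Each rank only holds its slice of the vocabulary, so the local sum is a
    partial dot product; the SUM all-reduce below recovers the full-vocab
    sum_i t_i * z_i exactly, with no rescaling by group.size().
    """
    partial = (target_shard * logits_shard).sum(dim=-1, keepdim=True)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)
    return partial
```

Because the vocabulary slices are disjoint, the reduced value is already the full-vocab dot product, which is why no averaging or rank-count correction appears in the loss.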
MAX_DROPLESS_BLOCK_SIZE_ROW = 128


class ReverseKLImpl(str, enum.Enum):
Not needed, there is only one implementation.
| "beta_fast": config["beta_fast"], | ||
| "beta_slow": config["beta_slow"], | ||
| "original_context_length": config["original_max_position_embeddings"], | ||
| "attention_factor": config["rope_scaling"]["attention_factor"], |
Should be the same for Llama 3 above?
@pytest.mark.parametrize("use_mask", [True, False])
def test_cross_entropy_vocab_tp_two_ranks(use_mask):
    _spawn_dist(2, _ce_vocab_tp_worker, use_mask)
Need to mark as slow (how long does this take?). Does it work on CPU? Also, it would be better to run all tests in the same spawn call because of the huge distributed overhead.
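One way to amortize that overhead is sketched below; the gloo backend choice and the placeholder case function are illustrative assumptions, not the repo's actual test utilities:

```python
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _case_all_reduce_sum(use_mask):
    # Placeholder for a real TP loss check (e.g. vocab-TP cross-entropy or
    # reverse KL against a single-device reference); here we only exercise the group.
    x = torch.ones(4) * (dist.get_rank() + 1)
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    assert torch.allclose(x, torch.full((4,), 3.0))  # 1 + 2 for two ranks


def _combined_worker(rank, world_size):
    # gloo so the test also runs on CPU-only machines.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    try:
        # Run all parametrized cases inside one process group so the expensive
        # spawn/init cost is paid once rather than once per test.
        for use_mask in (True, False):
            _case_all_reduce_sum(use_mask)
    finally:
        dist.destroy_process_group()


@pytest.mark.slow
def test_tp_losses_two_ranks_single_spawn():
    mp.spawn(_combined_worker, args=(2,), nprocs=2)
```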
# then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i)
# = log Z - 1/K sum_ranks (sum_i t_i * z_i)
# but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab)
predicted_logits = predicted_logits * group.size()
This looks wrong, see previous comment. The previous version was tested and confirmed to work.
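A tiny numeric check of why the extra factor overcounts (illustrative values, two simulated ranks): the per-rank partial sums of `t_i * z_i` already add up to the full-vocab dot product under a SUM all-reduce, so multiplying by `group.size()` doubles it.

```python
import torch

# Full-vocab target weights and logits, then the same tensors split across two "ranks".
t = torch.tensor([0.1, 0.2, 0.3, 0.4])
z = torch.tensor([1.0, 2.0, 3.0, 4.0])

full = (t * z).sum()                                   # sum_i t_i * z_i over the whole vocab
partials = [(t[:2] * z[:2]).sum(), (t[2:] * z[2:]).sum()]
reduced = sum(partials)                                # what a SUM all-reduce would produce

assert torch.isclose(reduced, full)                    # already correct, no rescaling needed
assert not torch.isclose(reduced * 2, full)            # multiplying by group.size() overcounts
```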
TP support for reverse KL loss.
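For context, a minimal single-device sketch of the reverse KL objective being parallelized here, using one common distillation convention, KL(student || teacher); the actual TP sharding, masking, and scaling in this PR are not shown, and the names are illustrative:

```python
import torch
import torch.nn.functional as F


def reverse_kl_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(student || teacher), averaged over all token positions."""
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # The expectation is taken under the *student* distribution -- that is what
    # makes this the "reverse" direction compared to ordinary distillation KL.
    return (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1).mean()
```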
TODO:
test_rkl_loss.py should probably be integrated in the testing framework?
🔍 Type of change
Select all that apply:
📝 Changes
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.