
Conversation

@oleksost (Contributor) commented Dec 2, 2025

TP support for reverse KL loss.

  • Adds support for vocabulary-parallel reverse KL loss calculation using torch (no fused implementation); see the sketch below.
  • Sequence-parallel loss calculation is not supported, to keep things simple (I don't think we use sequence-parallel embeddings/head).
  • Also fixes a small bug in the CE loss when it is used for distillation.

TODO:

  • Should test_rkl_loss.py be integrated into the testing framework?
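
For reference, here is a minimal sketch of what the unfused, vocabulary-parallel reverse KL computation looks like in plain torch, assuming each rank holds a shard of the vocabulary dimension for both student and teacher logits. The names (`vocab_parallel_log_softmax`, `student_logits`, `teacher_logits`, `group`) are illustrative and not the PR's actual API:

```python
import torch
import torch.distributed as dist


def vocab_parallel_log_softmax(logits: torch.Tensor, group) -> torch.Tensor:
    # Global max over the full vocabulary for numerical stability.
    local_max = logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    shifted = logits - local_max
    # Global normalizer: sum of exp over all vocabulary shards.
    sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
    return shifted - sum_exp.log()


def vocab_parallel_reverse_kl(student_logits, teacher_logits, group):
    # Reverse KL: KL(q_student || p_teacher) = sum_i q_i * (log q_i - log p_i).
    student_logp = vocab_parallel_log_softmax(student_logits, group)
    teacher_logp = vocab_parallel_log_softmax(teacher_logits, group)
    # Each rank sums over its own vocabulary shard; the SUM all-reduce completes
    # the sum over the full vocabulary, so no scaling by the group size is needed.
    local_kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1)
    dist.all_reduce(local_kl, op=dist.ReduceOp.SUM, group=group)
    return local_kl.mean()
```

The same point (the per-shard sums are completed by the all-reduce, not averaged) comes up in the review discussion below.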

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

  • added _torch_reverse_kl_forward_backward in cross_entropy.py
  • added test_rkl_loss

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost marked this pull request as ready for review on December 2, 2025, 19:52.

```python
else:
    predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)

# shouldn't the predicted_logits be scaled by the number of ranks so that the average loss is correct? i.e.
```
Collaborator:
There is no averaging. We calculate log Z - sum_ranks (sum_i t_i * z_i_rank * mask) which is the same as log Z - sum_i t_i * z_i.
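
To spell out the step (a restatement of the comment above, with K ranks, V_k the vocabulary shard held by rank k, and m_i the loss mask):

$$
\log Z \;-\; \sum_{k=1}^{K} \sum_{i \in V_k} t_i \, z_i \, m_i \;=\; \log Z \;-\; \sum_{i=1}^{V} t_i \, z_i \, m_i .
$$

The shards V_1, ..., V_K partition the vocabulary, so the cross-rank sum already covers every vocabulary entry; multiplying by K afterwards would count each term K times.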

```python
MAX_DROPLESS_BLOCK_SIZE_ROW = 128


class ReverseKLImpl(str, enum.Enum):
```
Collaborator:
Not needed, there is only one implementation.

"beta_fast": config["beta_fast"],
"beta_slow": config["beta_slow"],
"original_context_length": config["original_max_position_embeddings"],
"attention_factor": config["rope_scaling"]["attention_factor"],
Collaborator:
Should be the same for Llama 3 above?


```python
@pytest.mark.parametrize("use_mask", [True, False])
def test_cross_entropy_vocab_tp_two_ranks(use_mask):
    _spawn_dist(2, _ce_vocab_tp_worker, use_mask)
```
@jlamypoirier (Collaborator) commented Dec 3, 2025:
This needs to be marked as slow (how long does it take?). Does it work on CPU? Also, it would be better to run all the tests in the same spawn call because of the huge distributed overhead.
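
A hedged sketch of one way to fold every parametrization into a single spawn so the process group is only set up once; the backend, port, helper names, and the `slow` marker are assumptions, not the repository's actual test utilities:

```python
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp


def _all_cases_worker(rank: int, world_size: int):
    # One process group for every case, instead of one spawn per parametrization.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29511", rank=rank, world_size=world_size
    )
    try:
        for use_mask in (True, False):
            ...  # run the vocab-TP cross-entropy / reverse-KL check for this setting
    finally:
        dist.destroy_process_group()


@pytest.mark.slow
def test_cross_entropy_vocab_tp_two_ranks_all_cases():
    mp.spawn(_all_cases_worker, args=(2,), nprocs=2, join=True)
```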

```python
# then we average: 1/K sum_ranks (log Z - sum_i t_i * z_i)
# = log Z - 1/K sum_ranks (sum_i t_i * z_i)
# but sum_ranks (sum_i t_i * z_i) = sum_i t_i * z_i (over all vocab)
predicted_logits = predicted_logits * group.size()
```
Collaborator:
This looks wrong; see the previous comment. The previous version was tested and confirmed to work.
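
A tiny numeric check of the same point, with made-up numbers rather than the PR's code: splitting the dot product across two "ranks" and summing the parts already reproduces the full-vocabulary value, so an extra factor of K over-counts it.

```python
import torch

target = torch.tensor([0.1, 0.2, 0.3, 0.4])
logits = torch.tensor([1.0, 2.0, 3.0, 4.0])

full = (target * logits).sum()
# Pretend rank 0 holds the first two vocab entries and rank 1 the last two.
per_rank = [(target[:2] * logits[:2]).sum(), (target[2:] * logits[2:]).sum()]

assert torch.isclose(per_rank[0] + per_rank[1], full)            # the all-reduce SUM is already complete
assert not torch.isclose(2 * (per_rank[0] + per_rank[1]), full)  # scaling by the group size over-counts
```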
