Reduce memory requirements for checkpoint conversion #2636
Conversation
Force-pushed from c3e55a8 to e99a0bd
🤖 Hi @khatwanimohit, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces a lazy_load feature for checkpoint conversion, which significantly reduces peak memory usage. The implementation is well-structured, introducing LazyHFLoader and LazyTensor classes to handle on-demand loading of model weights, and integrating them with Orbax through a custom LazyTensorHandler.
🔍 General Feedback
- The code is clean, well-documented, and includes helpful additions like RAM usage logging and a memory-monitoring progress bar.
- The separate handling for `safetensors` and PyTorch binary files is robust.
- The use of `functools.partial` to create loading functions is elegant (see the sketch below).

I've left a few minor suggestions for typos and docstring updates. Overall, this is an excellent contribution that will be very beneficial for users working with large models.
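As a rough illustration of the pattern described above (not the PR's actual classes), here is a minimal Python sketch of a lazy tensor wrapper that uses `functools.partial` to defer loading and transformation until the value is materialized; all names here (`LazyTensorSketch`, `load_and_transpose`, `weights.npz`, `decoder/kernel`) are hypothetical.

```python
# Minimal sketch of the lazy-loading pattern (hypothetical names, not the PR's code).
import functools

import numpy as np


class LazyTensorSketch:
  """Defers loading/transforming a weight until materialize() is called."""

  def __init__(self, load_fn, shape, dtype):
    self._load_fn = load_fn  # zero-argument callable that returns the actual array
    self.shape = shape
    self.dtype = dtype

  def materialize(self) -> np.ndarray:
    # The expensive read + transform happens only here, one tensor at a time.
    return self._load_fn()


def load_and_transpose(path: str, key: str) -> np.ndarray:
  """Example load function: read a single array from an .npz file and transpose it."""
  with np.load(path) as archive:
    return archive[key].T


# functools.partial binds the arguments now but defers the actual work until
# the checkpoint saver calls materialize() on this tensor.
lazy_weight = LazyTensorSketch(
    load_fn=functools.partial(load_and_transpose, "weights.npz", "decoder/kernel"),
    shape=(4096, 4096),
    dtype=np.float32,
)
# array = lazy_weight.materialize()  # nothing is read from disk until this line runs
```

Because each load function touches disk only when called, the saver can materialize one tensor at a time and release it before moving on, which is what keeps peak RAM low.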
shuningjin left a comment
Thank you for adding this nice feature to significantly improve memory usage! If you have a time comparison, it would be great to include it in the PR description.
Force-pushed from 81eb54e to 10cbbb0
RissyRan left a comment
Thank you!
Force-pushed from 10cbbb0 to 14df4ae
Description
This PR introduces a lazy_load configuration option that reduces peak RAM usage by deferring the loading and transformation of weights until the exact moment Orbax needs to save them to disk.
Note: the lazy-loading feature is temporarily disabled for multimodal models because the key names in the checkpoint files differ from those seen when loading a HF model with AutoConfig.
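For context on why deferring the reads reduces peak RAM, here is a minimal sketch (assuming a safetensors-format HF checkpoint; the shard filename and the hand-off step are placeholders) that reads one tensor at a time with `safetensors.safe_open` rather than loading an entire shard into memory:

```python
# Sketch: read HF safetensors weights one tensor at a time so only a single
# tensor is resident in RAM at any moment (shard name below is a placeholder).
from safetensors import safe_open

shard_path = "model-00001-of-00030.safetensors"  # placeholder shard filename

with safe_open(shard_path, framework="np") as shard:
  for key in shard.keys():
    tensor = shard.get_tensor(key)  # loads just this tensor from disk
    # ... transform the tensor and hand it to the checkpoint saver here ...
    del tensor  # release it before reading the next key
```

PyTorch `.bin` shards are pickle archives and do not offer this kind of per-key access, which is presumably why the two formats need separate handling.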
Key Changes
If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/458745828
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
x here denotes the model's parameter count in billions; e.g. for the 70B model, 8.8x GB = 8.8 × 70 ≈ 616 GB.
Llama-3.1-70B
old RAM usage: 616 GB (8.8x GB)
new RAM usage: 86.2 GB (1.2x GB)
Logs: https://paste.googleplex.com/6288510168465408#l=571
Llama-3.1-8B
old RAM usage: 51 GB (6.3x GB)
new RAM usage: 31 GB (4x GB)
Logs: https://paste.googleplex.com/6128000672333824
forward pass logits test: https://paste.googleplex.com/5595374018494464
Qwen3-4B
old RAM usage: 37 GB (9.2x GB)
new RAM usage: 15.7 GB (3.7x GB)
Logs: https://paste.googleplex.com/4933249042350080#l=543
forward pass logits test: https://paste.googleplex.com/5515770054443008
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.