Conversation

@khatwanimohit (Collaborator) commented Nov 7, 2025

Description

This PR introduces a lazy_load configuration option that reduces peak RAM usage during checkpoint conversion by deferring the loading and transformation of weights until the exact moment Orbax needs to save them to disk.

Note: the lazy loading feature is temporarily disabled for multimodal models because the key names in the checkpoint files differ from those seen when loading a HF model via AutoConfig.

Key Changes

  • Lazy Loading Mechanism: Implemented LazyHFLoader and LazyTensor classes. Instead of eagerly loading all weights, we now create lightweight proxy objects that load their specific tensor data from disk only when Orbax calls array() during the saving phase (see the sketch after this list).
  • Orbax Integration: Registered a custom LazyTensorHandler with Orbax to allow it to correctly recognize and process our proxy objects as if they were standard NumPy arrays.
  • Config: Added a new boolean flag lazy_load to the MaxText configuration to enable this mode. It defaults to False to preserve existing behavior for smaller models.
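
A minimal sketch of the proxy idea from the first bullet, assuming a safetensors-backed checkpoint; class and attribute names are illustrative, not the PR's actual LazyTensor implementation (which also handles PyTorch binary files and carries more metadata):

```python
# Illustrative sketch only; not the PR's actual LazyTensor code.
from safetensors import safe_open

class LazyTensor:
  """Proxy that defers reading a tensor from disk until it is materialized."""

  def __init__(self, filepath, key, shape, dtype):
    self.filepath = filepath  # checkpoint shard containing this weight
    self.key = key            # tensor name inside the file
    self.shape = shape        # known up front from the file header
    self.dtype = dtype

  def __array__(self, dtype=None):
    # Called when a consumer (e.g., the checkpoint saver) needs real data;
    # only this one tensor is read, so peak RAM stays near one tensor's size.
    with safe_open(self.filepath, framework="numpy") as f:
      data = f.get_tensor(self.key)
    return data.astype(dtype) if dtype is not None else data
```

With lazy_load defaulting to False, existing conversions are unaffected; opting in would follow the usual MaxText key=value override style (e.g., lazy_load=true), though the exact invocation depends on the conversion script.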

FIXES: b/458745828


Tests

Here x denotes the model's parameter count in billions, so "8.8x GB" for the 70B model means 8.8 GB per billion parameters, i.e., 8.8 × 70 ≈ 616 GB. (A sketch of how such RAM figures can be sampled follows the results.)

Llama-3.1-70B

old RAM usage: 616 GB (8.8x GB)
new RAM usage: 86.2 GB (1.2x GB)
Logs: https://paste.googleplex.com/6288510168465408#l=571

Llama-3.1-8B

old RAM usage: 51 GB (6.3x GB)
new RAM usage: 31 GB (4x GB)
Logs: https://paste.googleplex.com/6128000672333824
forward pass logits test: https://paste.googleplex.com/5595374018494464

Qwen3-4B

old RAM usage: 37 GB (9.2x GB)
new RAM usage: 15.7 GB (3.7x GB)
Logs: https://paste.googleplex.com/4933249042350080#l=543
forward pass logits test: https://paste.googleplex.com/5515770054443008
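
The logs above are Google-internal. As a rough illustration of how such RAM figures can be sampled, here is a hypothetical psutil-based helper; the PR's own RAM logging may differ:

```python
# Hypothetical helper for sampling resident memory; the PR's actual RAM
# logging and memory-monitoring progress bar may be implemented differently.
import psutil

def rss_gb() -> float:
  """Resident set size of the current process, in GB."""
  return psutil.Process().memory_info().rss / 1024**3

print(f"RAM usage: {rss_gb():.1f} GB")
```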

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions bot commented

🤖 Hi @khatwanimohit, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment
📋 Review Summary

This pull request introduces a lazy_load feature for checkpoint conversion, which significantly reduces peak memory usage. The implementation is well-structured, introducing LazyHFLoader and LazyTensor classes to handle on-demand loading of model weights, and integrating them with Orbax through a custom LazyTensorHandler.

🔍 General Feedback

  • The code is clean, well-documented, and includes helpful additions like RAM usage logging and a memory-monitoring progress bar.
  • The separate handling for safetensors and PyTorch binary files is robust.
  • The use of functools.partial to create loading functions is elegant (see the small illustration after this comment).

I've left a few minor suggestions for typos and docstring updates. Overall, this is an excellent contribution that will be very beneficial for users working with large models.
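
As a small illustration of the functools.partial pattern mentioned above (file and tensor names here are hypothetical, not the PR's code):

```python
# Illustrative only: binding per-tensor loader callables with functools.partial.
import functools
from safetensors import safe_open

def _load_tensor(filepath, key):
  with safe_open(filepath, framework="numpy") as f:
    return f.get_tensor(key)

# One zero-argument loader per weight; nothing touches disk until it is called.
loaders = {
    key: functools.partial(_load_tensor, "model-00001-of-00002.safetensors", key)
    for key in ("model.embed_tokens.weight", "lm_head.weight")
}

tensor = loaders["lm_head.weight"]()  # the disk read happens here
```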

@shuningjin (Collaborator) left a comment
Thank you for adding this nice feature to significantly improve memory usage! If you have a time comparison, it would be great to include it in the PR description.

@khatwanimohit force-pushed the mohit/memory_opt branch 3 times, most recently from 81eb54e to 10cbbb0 on November 12, 2025 21:13
@RissyRan (Collaborator) left a comment

Thank you!

@copybara-service bot merged commit 9642e89 into main on November 12, 2025 (35 checks passed).
@copybara-service bot deleted the mohit/memory_opt branch on November 12, 2025 22:35.