Commit ab24f1e

Merge pull request #2585 from AI-Hypercomputer:hengtaoguo-conv

PiperOrigin-RevId: 827661077
2 parents: 8d6acdd + b15703f

File tree (2 files changed: 41 additions, 2 deletions):

  src/MaxText/checkpointing.py
  src/MaxText/utils/ckpt_conversion/to_maxtext.py

src/MaxText/checkpointing.py (3 additions, 1 deletion)

@@ -427,7 +427,9 @@ def _restore_grain_iterator(
   elif expansion_factor_real_data > 1 and process_count_stored == process_count_jax // expansion_factor_real_data:
     # Scaling up to a larger number of hosts (e.g., 32 files -> 64 processes).
     # In this case, a subset of hosts restores the data iterator.
-    assert not isinstance(data_iterator, list), "when expansion_factor_real_data > 1, the data iterator should not be a list."
+    assert not isinstance(
+        data_iterator, list
+    ), "when expansion_factor_real_data > 1, the data iterator should not be a list."
     grain_restore_args = GrainCheckpointRestore(
         data_iterator.local_iterator, process_index=jax.process_index(), process_count=process_count_stored
     )
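For context, a minimal sketch (hypothetical values, not MaxText code) of the host scale-up relationship this branch and its new assert expect:

# Hypothetical example: iterator state written by 32 hosts, restored on 64 hosts.
process_count_stored = 32       # hosts that wrote the checkpointed iterator state
process_count_jax = 64          # hosts in the current, larger job
expansion_factor_real_data = 2  # each stored state now covers this many hosts

# The elif condition above: when it holds, only a subset of hosts restores the
# iterator, and data_iterator must be a single iterator rather than a list.
assert process_count_stored == process_count_jax // expansion_factor_real_data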

src/MaxText/utils/ckpt_conversion/to_maxtext.py (38 additions, 1 deletion)

@@ -48,9 +48,11 @@
 
 import numpy as np
 import jax
+import psutil
 from absl import app
 from flax.training import train_state
 from transformers import AutoConfig, AutoModelForCausalLM
+from tqdm import tqdm
 
 from MaxText import checkpointing
 from MaxText import max_logging
@@ -67,6 +69,39 @@
 jax.config.update("jax_platform_name", "cpu")
 
 
+class MemoryMonitorTqdm(tqdm):
+  """Custom tqdm class that displays memory usage in the progress bar."""
+
+  def format_meter(
+      self,
+      n,
+      total,
+      elapsed,
+      postfix=None,
+      **extra_kwargs,
+  ):
+    """Override to add memory usage info to the postfix."""
+    # Get memory info.
+    memory = psutil.virtual_memory()
+    used_gb = memory.used / (1024**3)
+    total_gb = memory.total / (1024**3)
+    memory_percent = memory.percent
+
+    # Create memory postfix.
+    memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"
+
+    # Add memory info to postfix.
+    if postfix:
+      if isinstance(postfix, dict):
+        postfix["memory"] = memory_info
+      else:
+        postfix = f"{postfix}, {memory_info}"
+    else:
+      postfix = memory_info
+
+    return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
+
+
 def _build_multi_axis_stacked_tensor(
     hf_source_keys: List[List[str]], hf_state_dict: Dict[str, np.ndarray], hook_fns: Any
 ) -> np.ndarray:
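A standalone usage sketch of the new class (illustrative only; the loop body is a stand-in, not part of this commit), assuming psutil and tqdm are installed:

import time

# Wrapping any iterable gives the usual progress bar plus a postfix such as
# "RAM: 12.3/64.0GB (19.2%)" rendered by the overridden format_meter.
for _ in MemoryMonitorTqdm(range(100), desc="Demo", unit="step"):
  time.sleep(0.01)  # stand-in for real per-item work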
@@ -229,7 +264,9 @@ def main(argv: Sequence[str]) -> None:
   max_logging.log("Starting weight transformation...")
   final_mt_weights = []
 
-  for path_tuple, abstract_leaf_value in abstract_params_flat:
+  for path_tuple, abstract_leaf_value in MemoryMonitorTqdm(
+      abstract_params_flat, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
+  ):
     key_parts = [k.key for k in path_tuple]
     mt_param_key = "params-" + "-".join(key_parts)
     mt_target_shape_final = abstract_leaf_value.shape
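The memory readings come from psutil.virtual_memory(), which returns a named tuple; a quick sketch of the fields the progress bar uses (.total and .used are in bytes, .percent is a float):

import psutil

mem = psutil.virtual_memory()
print(f"RAM: {mem.used / 1024**3:.1f}/{mem.total / 1024**3:.1f}GB ({mem.percent:.1f}%)")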
