|
48 | 48 |
|
49 | 49 | import numpy as np |
50 | 50 | import jax |
| 51 | +import psutil |
51 | 52 | from absl import app |
52 | 53 | from flax.training import train_state |
53 | 54 | from transformers import AutoConfig, AutoModelForCausalLM |
| 55 | +from tqdm import tqdm |
54 | 56 |
|
55 | 57 | from MaxText import checkpointing |
56 | 58 | from MaxText import max_logging |
|
67 | 69 | jax.config.update("jax_platform_name", "cpu") |
68 | 70 |
|
69 | 71 |
|
class MemoryMonitorTqdm(tqdm):
  """tqdm progress bar that appends host RAM usage to the postfix.

  Memory figures are read from `psutil.virtual_memory()` each time the
  meter is formatted, so the displayed value tracks the refresh rate of
  the progress bar itself.
  """

  def format_meter(
      self,
      n,
      total,
      elapsed,
      postfix=None,
      **extra_kwargs,
  ):
    """Format the progress line with a RAM usage summary in the postfix.

    Args:
      n: Iteration count, forwarded unchanged to the base class.
      total: Expected iteration count, forwarded unchanged.
      elapsed: Elapsed time, forwarded unchanged.
      postfix: Existing postfix, if any. A dict is rendered as
        comma-separated `key=value` pairs; any other truthy value is
        interpolated as text. The RAM summary is appended after it.
      **extra_kwargs: Remaining formatting options, forwarded unchanged.

    Returns:
      The result of the base class `format_meter` for the augmented postfix.
    """
    memory = psutil.virtual_memory()
    bytes_per_gb = 1024**3
    used_gb = memory.used / bytes_per_gb
    total_gb = memory.total / bytes_per_gb
    memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory.percent:.1f}%)"

    if isinstance(postfix, dict):
      # Build a fresh string instead of inserting a key into the caller's
      # dict: the argument must not be mutated as a side effect of
      # rendering, and the base formatter is handed a string postfix.
      parts = [f"{key}={value}" for key, value in postfix.items()]
      parts.append(memory_info)
      postfix = ", ".join(parts)
    elif postfix:
      postfix = f"{postfix}, {memory_info}"
    else:
      # None or empty postfix: the RAM summary is the whole postfix.
      postfix = memory_info

    return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
| 103 | + |
| 104 | + |
70 | 105 | def _build_multi_axis_stacked_tensor( |
71 | 106 | hf_source_keys: List[List[str]], hf_state_dict: Dict[str, np.ndarray], hook_fns: Any |
72 | 107 | ) -> np.ndarray: |
@@ -229,7 +264,9 @@ def main(argv: Sequence[str]) -> None: |
229 | 264 | max_logging.log("Starting weight transformation...") |
230 | 265 | final_mt_weights = [] |
231 | 266 |
|
232 | | - for path_tuple, abstract_leaf_value in abstract_params_flat: |
| 267 | + for path_tuple, abstract_leaf_value in MemoryMonitorTqdm( |
| 268 | + abstract_params_flat, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True |
| 269 | + ): |
233 | 270 | key_parts = [k.key for k in path_tuple] |
234 | 271 | mt_param_key = "params-" + "-".join(key_parts) |
235 | 272 | mt_target_shape_final = abstract_leaf_value.shape |
|
0 commit comments