Skip to content

Commit 8c0b72b

Browse files
justinchubyCopilot
andauthored
Add a verbose mode to torch api for external data save (#2643)
Show progress bar with tqdm when verbose is True. It will be enabled in PyTorch 2.10 <img width="2261" height="171" alt="image" src="https://github.com/user-attachments/assets/3184a813-d6d3-4bbc-93ef-78fec481a36c" /> --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 04a9da4 commit 8c0b72b

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

onnxscript/_framework_apis/torch_2_5.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
]
1414

1515
import dataclasses
16+
import importlib.util
1617
import os
1718
import pathlib
1819
from typing import Callable
@@ -63,7 +64,9 @@ def check_model(model: ir.Model) -> None:
6364
del model # Unused yet
6465

6566

66-
def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
67+
def save_model_with_external_data(
68+
model: ir.Model, model_path: str | os.PathLike, verbose: bool = False
69+
) -> None:
6770
"""Save the model with external data. The model is unchanged after saving."""
6871

6972
# TODO(#1835): Decide if we want to externalize large attributes as well
@@ -78,7 +81,31 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
7881
destination_path = pathlib.Path(model_path)
7982
data_path = f"{destination_path.name}.data"
8083

81-
ir.save(model, model_path, external_data=data_path)
84+
# Show a progress bar if verbose is True and tqdm is installed
85+
use_tqdm = verbose and importlib.util.find_spec("tqdm") is not None
86+
87+
if use_tqdm:
88+
import tqdm # pylint: disable=import-outside-toplevel
89+
90+
with tqdm.tqdm() as pbar:
91+
total_set = False
92+
93+
def callback(
94+
tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo
95+
) -> None:
96+
nonlocal total_set
97+
if not total_set:
98+
pbar.total = metadata.total
99+
total_set = True
100+
101+
pbar.update()
102+
pbar.set_description(
103+
f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}"
104+
)
105+
106+
ir.save(model, model_path, external_data=data_path, callback=callback)
107+
else:
108+
ir.save(model, model_path, external_data=data_path)
82109

83110

84111
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:

0 commit comments

Comments
 (0)