From 1eeedc4fc9e03b68180f2f7ce20da7d4ddba9389 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 17 Nov 2025 10:29:17 +0000 Subject: [PATCH 1/4] Better side-by-side --- _unittests/ut_torch_onnx/test_sbs.py | 47 +++- onnx_diagnostic/_command_lines_parser.py | 12 +- onnx_diagnostic/torch_onnx/sbs.py | 296 +++++++++++------------ 3 files changed, 187 insertions(+), 168 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 6ae72e94..d44c46cc 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -10,7 +10,7 @@ ) from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -from onnx_diagnostic.torch_onnx.sbs import run_aligned, post_process_run_aligned_obs +from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord from onnx_diagnostic.export.api import to_onnx @@ -21,6 +21,24 @@ def setUpClass(cls): cls.torch = torch + def test_run_aligned_record(self): + r = RunAlignedRecord( + ep_id_node=-1, + onnx_id_node=-1, + ep_name="A", + onnx_name="B", + ep_target="C", + onnx_op_type="D", + shape_type="E", + err_abs=0.1, + err_rel=0.2, + err_dev=0.3, + err_nan=0.4, + ) + sr = str(r) + self.assertIn("RunAlignedRecord(", sr) + self.assertIn("shape_type='E'", sr) + @hide_stdout() @unittest.skipIf(to_onnx is None, "to_onnx not installed") @ignore_errors(OSError) # connectivity issues @@ -48,7 +66,7 @@ def forward(self, x): run_cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5, - verbose=1, + verbose=10, ), ) self.assertEqual(len(results), 7) @@ -83,7 +101,7 @@ def forward(self, x): run_cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5, - verbose=1, + verbose=10, ), ) self.assertEqual(len(results), 6) @@ -115,7 +133,7 @@ def forward(self, x): run_cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5, - verbose=1, + verbose=10, ), ) self.assertEqual(len(results), 6) @@ -285,7 +303,10 @@ def forward(self, x): ), ) self.assertEqual(len(results), 14) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 14) + self.assertEqual( + [r.err_dev for r in results], + [None, None, None, None, None, None, None, None, 0, 0, 0, 0, 0, 0], + ) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -323,7 +344,7 @@ def forward(self, x): use_tensor=True, ), ) - df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results))) + df = pandas.DataFrame(list(results)) df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom.xlsx")) self.assertEqual( [ @@ -332,6 +353,7 @@ def forward(self, x): "ep_target", "err_abs", "err_dev", + "err_nan", "err_rel", "onnx_id_node", "onnx_name", @@ -341,7 +363,10 @@ def forward(self, x): sorted(df.columns), ) self.assertEqual(len(results), 12) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 12) + self.assertEqual( + [r.err_dev for r in results], + [None, None, None, None, None, None, None, None, None, 0, 0, 0], + ) self.assertEqual( [-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0], df["onnx_id_node"].fillna(-10).tolist(), @@ -384,7 +409,7 @@ def forward(self, x): use_tensor=True, ), ) - df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results))) + df = pandas.DataFrame(list(results)) df.to_excel(self.get_dump_file("test_sbs_model_with_weights_dynamo.xlsx")) self.assertEqual( [ @@ -393,6 +418,7 @@ def forward(self, x): "ep_target", "err_abs", "err_dev", + "err_nan", "err_rel", "onnx_id_node", "onnx_name", @@ -402,7 +428,10 @@ def forward(self, x): sorted(df.columns), ) self.assertEqual(len(results), 12) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 12) + self.assertEqual( + [r.err_dev for r in results], + [None, None, None, None, None, None, None, None, None, 0, 0, 0], + ) self.assertEqual( [-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0], df["onnx_id_node"].fillna(-10).tolist(), diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 3f61915e..dbd40f67 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1169,9 +1169,9 @@ def get_parser_sbs() -> ArgumentParser: parser.add_argument( "-r", "--ratio", - default=5, + default=100, required=False, - help="Saves the result in an excel file every node.", + help="Saves the result in an excel file every nodes.", ) return parser @@ -1244,10 +1244,14 @@ def _size(name): pobs = post_process_run_aligned_obs(obs) data.append(pobs) if "initializer" not in pobs and "placeholder" not in pobs and len(data) % ratio == 0: - df = pandas.DataFrame(data) + df = pandas.DataFrame(data).apply( + lambda col: col.fillna("") if col.dtype == "object" else col + ) df.to_excel(args.output) print(f"-- final saves into {args.output!r}") - df = pandas.DataFrame(data) + df = pandas.DataFrame(data).apply( + lambda col: col.fillna("") if col.dtype == "object" else col + ) df.to_excel(args.output) print("-- done") diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 30576a4a..738cbe3b 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -1,4 +1,5 @@ import inspect +from dataclasses import dataclass from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import onnx import onnx.helper as oh @@ -6,7 +7,7 @@ import torch from ..helpers import string_type, string_diff, max_diff, flatten_object from ..helpers.onnx_helper import pretty_onnx -from ..helpers.torch_helper import to_numpy, from_numpy +from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor def validate_fx_tensor( @@ -170,42 +171,31 @@ def prepare_args_kwargs( return new_args, new_kwargs -def post_process_run_aligned_obs( - obs: Tuple[ - Optional[int], - Optional[int], - Optional[str], - Optional[str], - Optional[str], - Optional[str], - Dict[str, Optional[Union[int, float]]], - ], -) -> Dict[str, Optional[Union[str, float, int]]]: - """ - Flattens an observations produced by function - :func:`onnx_diagnostic.torch_onnx.sbs.run_aligned`. - """ - dobs = dict( - zip( - [ - "ep_id_node", - "onnx_id_node", - "ep_name", - "onnx_name", - "ep_target", - "onnx_op_type", - "shape_type", - ], - obs, - ) - ) - if "abs" in obs[-1] and obs[-1]["abs"] is not None: - dobs["err_abs"] = obs[-1]["abs"] # type: ignore[assignment] - if "rel" in obs[-1] and obs[-1]["rel"] is not None: - dobs["err_rel"] = obs[-1]["rel"] # type: ignore[assignment] - if "dev" in obs[-1] and obs[-1]["dev"] is not None: - dobs["err_dev"] = obs[-1]["dev"] # type: ignore[assignment] - return dobs # type: ignore[return-value] +@dataclass +class RunAlignedRecord: + ep_id_node: Optional[int] = None + onnx_id_node: Optional[int] = None + ep_name: Optional[str] = None + onnx_name: Optional[str] = None + ep_target: Optional[str] = None + onnx_op_type: Optional[str] = None + shape_type: Optional[str] = None + err_abs: Optional[float] = None + err_rel: Optional[float] = None + err_dev: Optional[float] = None + err_nan: Optional[float] = None + + def set_diff(self, diff: Dict[str, Any]): + if diff is None: + return + if "abs" in diff: + self.err_abs = diff["abs"] + if "rel" in diff: + self.err_rel = diff["rel"] + if "dev" in diff: + self.err_dev = diff["dev"] + if "nan" in diff: + self.err_nan = diff["nan"] def run_aligned( @@ -229,7 +219,7 @@ def run_aligned( rtol: Optional[float] = None, verbose: int = 0, exc: bool = True, -) -> Iterator[Tuple[Any, ...]]: +) -> Iterator[RunAlignedRecord]: """ Runs in parallel both the exported program and the onnx proto and looks for discrepancies. @@ -247,18 +237,7 @@ def run_aligned( :param rtol: relative tolerance :param verbose: verbosity level :param exc: stops if an exception - :return: a list of tuples containing the results, they come in tuple - - Each tuple is: - - - ep_id_node - - onnx_id_node - - ep_name - - onnx_name - - ep target name - - onnx op _type - - ep or onnx shape and type - - difference + :return: a list of :class:`RunAlignedRecord` Example: @@ -269,13 +248,10 @@ def run_aligned( import pandas import torch from onnx_diagnostic.reference import ( - # This can be replace by any runtime taking NodeProto as an input. + # This can be replaced by any runtime taking NodeProto as an input. ExtendedReferenceEvaluator as ReferenceEvaluator, ) - from onnx_diagnostic.torch_onnx.sbs import ( - run_aligned, - post_process_run_aligned_obs, - ) + from onnx_diagnostic.torch_onnx.sbs import run_aligned class Model(torch.nn.Module): @@ -296,16 +272,12 @@ def forward(self, x): Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},) ).model_proto results = list( - map( - post_process_run_aligned_obs, - run_aligned( - ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1 - ), - ), + run_aligned(ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1) ) print("------------") print("final results") df = pandas.DataFrame(results) + df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col) print(df) @@ -361,10 +333,7 @@ def forward(self, x): import pandas import onnx import torch - from onnx_diagnostic.torch_onnx.sbs import ( - run_aligned, - post_process_run_aligned_obs, - ) + from onnx_diagnostic.torch_onnx.sbs import run_aligned from onnx_diagnostic.reference import OnnxruntimeEvaluator @@ -374,23 +343,21 @@ def forward(self, x): results = list( - map( - post_process_run_aligned_obs, - run_aligned( - ep, - onx, - OnnxruntimeEvaluator, - inputs, - atol=1e-5, - rtol=1e-5, - verbose=1, - use_tensor=True, - ), - ), + run_aligned( + ep, + onx, + OnnxruntimeEvaluator, + inputs, + atol=1e-5, + rtol=1e-5, + verbose=1, + use_tensor=True, + ) ) print("------------") print("final results") df = pandas.DataFrame(results) + df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col) print(df) A command line can also be run: @@ -469,13 +436,13 @@ def _loop_cmp( i_onnx, ): onnx_results[o] = _check_tensor_(o, r) - if verbose: + if verbose > 1: print(f"[run_aligned-nx] +res: {o}={string_type(r, **str_kws)}") to = mapping_onnx_to_torch.get(o, o) if to in torch_results: d = max_diff(torch_results[to], r) - if verbose: + if verbose > 1: if o == to: print(f"[run_aligned-==] cmp {to}: {string_diff(d)}") else: @@ -493,7 +460,15 @@ def _loop_cmp( ) else: print(f"[run_align-dx] discrepancies {string_diff(d)} - [{to}/{o}]") - return (i, i_onnx, o, to, string_type(torch_results[to], **str_kws), d) + r = RunAlignedRecord( + ep_id_node=i, + onnx_id_node=i_onnx, + ep_name=o, + onnx_name=to, + shape_type=string_type(torch_results[to], **str_kws), + ) + r.set_diff(d) + return r return None if verbose: @@ -517,34 +492,29 @@ def _loop_cmp( if verbose: print(f"[run_aligned] handles {len(onx.graph.initializer)} initializers from onnx") + memory_cpu = 0 + memory_cuda = 0 onnx_results: Dict[str, Any] = {} for init in onx.graph.initializer: # type: ignore positions[init.name] = -1 - t = run_cls( - _make_node_from_initializer(init), - **run_cls_kwargs, - ).run( # type: ignore[attr-defined] - None, {} - )[ - 0 - ] + t = to_tensor(init) if default_device and t.numel() >= 1024: # Let's force its way to cuda (should check the device has well). t = t.to(default_device) + size = t.element_size() * t.numel() + if t.is_cuda: + memory_cuda += size + else: + memory_cpu += size onnx_results[init.name] = _check_tensor_(init.name, t, flip_type=True) - if init.name.startswith("init"): - # not a weight - continue if verbose: - print(f"[run_aligned] handles common {len(onnx_results)} initializer from torch") + print(f"[run_aligned] handled {len(onnx_results)} initializers from onnx") + print(f"[run_aligned] onnx memory cpu {memory_cpu / 2**20:.3f} Mb") + print(f"[run_aligned] onnx memory cuda {memory_cuda / 2**20:.3f} Mb") # we should be careful, torch may modified inplace the weights, # it may be difficult to share weights torch_results: Dict[str, Any] = {} - if verbose: - print( - f"[run_aligned] handles other constant from {len(ep.graph.nodes)} nodes from torch" - ) last_position = 0 torch_output_names = None for node in ep.graph.nodes: @@ -560,13 +530,16 @@ def _loop_cmp( mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names)) if verbose: - print(f"[run_aligned] torch {len(torch_results)} constants") print(f"[run_aligned] onnx {len(onnx_results)} constants") - print(f"[run_aligned] common {len(mapping_onnx_to_torch)} constants") - for k, v in torch_results.items(): - print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}") - for k, v in onnx_results.items(): - print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}") + print(f"[run_aligned] onnx {len(onx.graph.input)} inputs") + print(f"[run_aligned] onnx {len(onx.graph.output)} outputs") + print(f"[run_aligned] common {len(mapping_onnx_to_torch)} outputs") + print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}") + if verbose > 1: + for k, v in torch_results.items(): + print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}") + for k, v in onnx_results.items(): + print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}") onnx_args = list(args) if args else [] if kwargs: @@ -589,19 +562,25 @@ def _loop_cmp( } for n in onnx_results: if n not in placeholders: - yield ( - None, - -1, - None, - n, - None, - "initializer", - string_type(onnx_results[n], **str_kws), - {}, + yield RunAlignedRecord( + onnx_id_node=-1, + onnx_name=n, + onnx_op_type="initializer", + shape_type=string_type(onnx_results[n], **str_kws), ) ep_graph_nodes = list(ep.graph.nodes) - for i, node in enumerate(ep_graph_nodes): - if verbose: + + if verbose == 1: + import tqdm + + loop = tqdm.tqdm(list(enumerate(ep_graph_nodes))) + else: + loop = list(enumerate(ep_graph_nodes)) + + yielded_nodes = 0 + max_abs = 0 + for i, node in loop: + if verbose > 1: if node.op == "call_function": print( f"[run_aligned] run ep.graph.nodes[{i}]: " @@ -609,6 +588,11 @@ def _loop_cmp( ) else: print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}") + elif verbose == 1: + loop.set_description( + f"ep {i}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} " + f"mapped {yielded_nodes} maxabs {max_abs:1.5f}" + ) if node.op == "placeholder": is_input = node.name in placeholders @@ -622,44 +606,40 @@ def _loop_cmp( if use_tensor else torch.from_numpy(onnx_results[node.name]) ) - if verbose: + if verbose > 1: t = torch_results[node.name] print(f"[run_aligned-ep] =plh: {node.name}={string_type(t, **str_kws)}") # Otherwise, it is an input. - yield ( - -1, - -1, - node.name, - node.name, - "input" if is_input else "placeholder", - "input" if is_input else "initializer", - string_type(t, **str_kws), - ( - {} - if is_input - else max_diff( + record = RunAlignedRecord( + ep_id_node=-1, + onnx_id_node=-1, + ep_name=node.name, + onnx_name=node.name, + ep_target=("input" if is_input else "placeholder"), + onnx_op_type=("input" if is_input else "initializer"), + shape_type=string_type(t, **str_kws), + ) + if not is_input: + record.set_diff( + max_diff( ep_state_dict[placeholders_to_state_dict[node.name]], onnx_results[node.name], ) - ), - ) + ) + yield record else: assert node.name in placeholders_to_state_dict, ( f"Unable to find placeholder {node.name!r} in " f"{sorted(placeholders_to_state_dict)}" ) torch_results[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]] - if verbose: + if verbose > 1: print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") - yield ( - -1, - None, - node.name, - None, - "placeholder", - None, - string_type(torch_results[node.name], **str_kws), - {}, + yield RunAlignedRecord( + ep_id_node=-1, + ep_name=node.name, + ep_target="placeholder", + shape_type=string_type(torch_results[node.name], **str_kws), ) continue @@ -675,7 +655,7 @@ def _loop_cmp( for k, v in zip(outputs, new_outputs): torch_results[k] = v - if verbose: + if verbose > 1: for k, v in zip(outputs, new_outputs): print(f"[run_aligned-ep] +res: {k}={string_type(v, **str_kws)}") @@ -689,7 +669,7 @@ def _loop_cmp( for i_onnx in range(last_position, max_pos + 1): node = onx.graph.node[i_onnx] - if verbose: + if verbose > 1: print( f"[run_aligned] run onx.graph.node[{i_onnx}]: " f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}" @@ -734,21 +714,24 @@ def _loop_cmp( i_onnx, ) if tmp is not None: - yield ( - *tmp[:4], - str(ep_graph_nodes[tmp[0]].target), - onx.graph.node[tmp[1]].op_type, - *tmp[-2:], - ) + tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target) + tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type + yielded_nodes += 1 + if tmp.err_abs is not None: + max_abs = max(max_abs, tmp.err_abs) + yield tmp last_position = max_pos + 1 # complete the execution of the onnx graph if verbose: - print(f"[run_aligned] complete execution of onnx graph from pos={last_position}") + print( + f"[run_aligned] complete execution of onnx graph from pos={last_position} " + f"to {len(onx.graph.node)}" + ) for i_onnx in range(last_position, len(onx.graph.node)): node = onx.graph.node[i_onnx] - if verbose: + if verbose > 1: print( f"[run_aligned] run onx.graph.node[{i_onnx}]: " f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}" @@ -770,9 +753,12 @@ def _loop_cmp( i_onnx, ) if tmp is not None: - yield ( - *tmp[:4], - str(ep_graph_nodes[tmp[0]].target), - onx.graph.node[tmp[1]].op_type, - *tmp[-2:], - ) + tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target) + tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type + yielded_nodes += 1 + if tmp.err_abs is not None: + max_abs = max(max_abs, tmp.err_abs) + yield tmp + if verbose: + print(f"[run_aligned] done with {yielded_nodes} mapped nodes") + print(f"[run_aligned] max absolution error={max_abs}") From 9e47e33fc207dfe20f87515a9670ed8e2fb4655f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 17 Nov 2025 10:31:40 +0000 Subject: [PATCH 2/4] fix command lien --- onnx_diagnostic/_command_lines_parser.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index dbd40f67..2b1643a6 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1180,7 +1180,7 @@ def _cmd_sbs(argv: List[Any]): import pandas import torch from .helpers import string_type - from .torch_onnx.sbs import run_aligned, post_process_run_aligned_obs + from .torch_onnx.sbs import run_aligned from .reference import OnnxruntimeEvaluator parser = get_parser_sbs() @@ -1241,9 +1241,12 @@ def _size(name): use_tensor=True, exc=False, ): - pobs = post_process_run_aligned_obs(obs) - data.append(pobs) - if "initializer" not in pobs and "placeholder" not in pobs and len(data) % ratio == 0: + data.append(obs) + if ( + obs.onnx_op_type != "initializer" + and onnx.ep_target != "placeholder" + and len(data) % ratio == 0 + ): df = pandas.DataFrame(data).apply( lambda col: col.fillna("") if col.dtype == "object" else col ) From f6c684ef98774716229500fd0e88d58e679d3168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 17 Nov 2025 11:06:44 +0000 Subject: [PATCH 3/4] last changes --- CHANGELOGS.rst | 2 +- onnx_diagnostic/_command_lines_parser.py | 16 +++++++++++++--- onnx_diagnostic/torch_onnx/sbs.py | 11 +++++++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 96ba2c91..dbeb31bf 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,7 @@ Change Logs 0.8.3 +++++ -* :pr:`304`: improves side-by-side comparison +* :pr:`304`, :pr:`306`: improves side-by-side comparison 0.8.2 +++++ diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 2b1643a6..d129b748 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1114,10 +1114,20 @@ def get_parser_sbs() -> ArgumentParser: the exported onnx model. It assumes some names are common. The execution of the exported program and the onnx model are done in parallel. The device is the one used to store the - model and the inputs.s + model and the inputs. + Where do discrepancies start? This function tries to answer that question. + """ + ), + epilog=textwrap.dedent( + """ + The command line expects the following files to be saved with + the following function. inputs is a dictionary of the input of the model. + + - torch.export.save(ep: torch.export.ExportedProgram) + - torch.save(**inputs) + - onnx.save(...) """ ), - epilog="Where do discrepancies start? This function tries to answer that question.", ) parser.add_argument( "-i", @@ -1244,7 +1254,7 @@ def _size(name): data.append(obs) if ( obs.onnx_op_type != "initializer" - and onnx.ep_target != "placeholder" + and obs.ep_target != "placeholder" and len(data) % ratio == 0 ): df = pandas.DataFrame(data).apply( diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 738cbe3b..3261db90 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -674,6 +674,11 @@ def _loop_cmp( f"[run_aligned] run onx.graph.node[{i_onnx}]: " f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}" ) + elif verbose == 1: + loop.set_description( + f"ep {i}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} " + f"mapped {yielded_nodes} maxabs {max_abs:1.5f}" + ) ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) # type: ignore[attr-defined] @@ -700,7 +705,8 @@ def _loop_cmp( f"res={string_type(res, with_device=True, with_shape=True)}, " f"node is {pretty_onnx(node)}" ) - for o, r in zip(node.output, res): + node_output = [o for o in node.output if o] + for o, r in zip(node_output, res): tmp = _loop_cmp( mapping_onnx_to_torch, torch_results, @@ -739,7 +745,8 @@ def _loop_cmp( ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) # type: ignore[attr-defined] - for o, r in zip(node.output, res): + node_output = [o for o in node.output if o] + for o, r in zip(node_output, res): tmp = _loop_cmp( mapping_onnx_to_torch, torch_results, From 3c10ccad41b7416a44ad1a636b62bf7ab039e02c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 17 Nov 2025 12:18:38 +0000 Subject: [PATCH 4/4] small changes --- _unittests/ut_torch_onnx/test_sbs.py | 7 ++++- onnx_diagnostic/torch_onnx/sbs.py | 46 ++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index d44c46cc..55f873d6 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -200,7 +200,6 @@ def forward(self, x): ), ) self.assertEqual(len(results), 8) - self.clean_dump() @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -351,13 +350,16 @@ def forward(self, x): "ep_id_node", "ep_name", "ep_target", + "ep_time_run", "err_abs", "err_dev", "err_nan", "err_rel", "onnx_id_node", + "onnx_id_output", "onnx_name", "onnx_op_type", + "onnx_time_run", "shape_type", ], sorted(df.columns), @@ -416,13 +418,16 @@ def forward(self, x): "ep_id_node", "ep_name", "ep_target", + "ep_time_run", "err_abs", "err_dev", "err_nan", "err_rel", "onnx_id_node", + "onnx_id_output", "onnx_name", "onnx_op_type", + "onnx_time_run", "shape_type", ], sorted(df.columns), diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 3261db90..b304397c 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -1,4 +1,5 @@ import inspect +import time from dataclasses import dataclass from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import onnx @@ -179,11 +180,14 @@ class RunAlignedRecord: onnx_name: Optional[str] = None ep_target: Optional[str] = None onnx_op_type: Optional[str] = None + onnx_id_output: Optional[int] = None shape_type: Optional[str] = None err_abs: Optional[float] = None err_rel: Optional[float] = None err_dev: Optional[float] = None err_nan: Optional[float] = None + ep_time_run: Optional[float] = None + onnx_time_run: Optional[float] = None def set_diff(self, diff: Dict[str, Any]): if diff is None: @@ -514,12 +518,20 @@ def _loop_cmp( print(f"[run_aligned] onnx memory cuda {memory_cuda / 2**20:.3f} Mb") # we should be careful, torch may modified inplace the weights, # it may be difficult to share weights + ep_graph_nodes = list(ep.graph.nodes) torch_results: Dict[str, Any] = {} last_position = 0 torch_output_names = None - for node in ep.graph.nodes: + name_to_ep_node = {} + for i, node in enumerate(ep_graph_nodes): if node.op == "output": torch_output_names = [n.name for n in node.args[0]] + assert isinstance(node.name, str), ( + f"Unexpected type {type(node.name)} for node={node} (target={node.target}), " + f"args={node.args}" + ) + name_to_ep_node[node.name] = i + onnx_outputs_names = [o.name for o in onx.graph.output] assert torch_output_names is not None and len(torch_output_names) == len( onnx_outputs_names @@ -568,7 +580,6 @@ def _loop_cmp( onnx_op_type="initializer", shape_type=string_type(onnx_results[n], **str_kws), ) - ep_graph_nodes = list(ep.graph.nodes) if verbose == 1: import tqdm @@ -577,6 +588,7 @@ def _loop_cmp( else: loop = list(enumerate(ep_graph_nodes)) + ep_durations = {} yielded_nodes = 0 max_abs = 0 for i, node in loop: @@ -645,7 +657,10 @@ def _loop_cmp( outputs = [node.name] if isinstance(node.name, str) else list(node.name) args, kwargs = prepare_args_kwargs(torch_results, node) + begin = time.perf_counter() new_outputs = run_fx_node(node, args, kwargs) + duration = time.perf_counter() - begin + ep_durations[i] = duration if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)): new_outputs = (new_outputs,) @@ -676,12 +691,14 @@ def _loop_cmp( ) elif verbose == 1: loop.set_description( - f"ep {i}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} " + f"ep {i}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} " f"mapped {yielded_nodes} maxabs {max_abs:1.5f}" ) ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} + begin = time.perf_counter() res = ref.run(None, feeds) # type: ignore[attr-defined] + duration = time.perf_counter() - begin assert ( not has_cuda or not any(t is not None and t.is_cuda for t in feeds.values()) @@ -705,7 +722,8 @@ def _loop_cmp( f"res={string_type(res, with_device=True, with_shape=True)}, " f"node is {pretty_onnx(node)}" ) - node_output = [o for o in node.output if o] + list_node_output = list(node.output) + node_output = [o for o in list_node_output if o] for o, r in zip(node_output, res): tmp = _loop_cmp( mapping_onnx_to_torch, @@ -720,8 +738,12 @@ def _loop_cmp( i_onnx, ) if tmp is not None: + tmp.ep_id_node = name_to_ep_node[tmp.ep_name] tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target) tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type + tmp.onnx_id_output = list_node_output.index(o) + tmp.ep_time_run = ep_durations[tmp.ep_id_node] + tmp.onnx_time_run = duration yielded_nodes += 1 if tmp.err_abs is not None: max_abs = max(max_abs, tmp.err_abs) @@ -744,8 +766,11 @@ def _loop_cmp( ) ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} + begin = time.perf_counter() res = ref.run(None, feeds) # type: ignore[attr-defined] - node_output = [o for o in node.output if o] + duration = time.perf_counter() - begin + list_node_output = list(node.output) + node_output = [o for o in list_node_output if o] for o, r in zip(node_output, res): tmp = _loop_cmp( mapping_onnx_to_torch, @@ -760,8 +785,17 @@ def _loop_cmp( i_onnx, ) if tmp is not None: - tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target) + if tmp.ep_name in name_to_ep_node: + tmp.ep_id_node = name_to_ep_node[tmp.ep_name] + tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target) + tmp.ep_time_run = ep_durations[tmp.ep_id_node] + else: + tmp.ep_id_node = None + tmp.ep_target = None + tmp.ep_name = None tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type + tmp.onnx_id_output = list_node_output.index(o) + tmp.onnx_time_run = duration yielded_nodes += 1 if tmp.err_abs is not None: max_abs = max(max_abs, tmp.err_abs)