1010)
1111from onnx_diagnostic .reference import ExtendedReferenceEvaluator , OnnxruntimeEvaluator
1212from onnx_diagnostic .torch_export_patches .patch_inputs import use_dyn_not_str
13- from onnx_diagnostic .torch_onnx .sbs import run_aligned , post_process_run_aligned_obs
13+ from onnx_diagnostic .torch_onnx .sbs import run_aligned , RunAlignedRecord
1414from onnx_diagnostic .export .api import to_onnx
1515
1616
@@ -21,6 +21,24 @@ def setUpClass(cls):
2121
2222 cls .torch = torch
2323
24+ def test_run_aligned_record (self ):
25+ r = RunAlignedRecord (
26+ ep_id_node = - 1 ,
27+ onnx_id_node = - 1 ,
28+ ep_name = "A" ,
29+ onnx_name = "B" ,
30+ ep_target = "C" ,
31+ onnx_op_type = "D" ,
32+ shape_type = "E" ,
33+ err_abs = 0.1 ,
34+ err_rel = 0.2 ,
35+ err_dev = 0.3 ,
36+ err_nan = 0.4 ,
37+ )
38+ sr = str (r )
39+ self .assertIn ("RunAlignedRecord(" , sr )
40+ self .assertIn ("shape_type='E'" , sr )
41+
2442 @hide_stdout ()
2543 @unittest .skipIf (to_onnx is None , "to_onnx not installed" )
2644 @ignore_errors (OSError ) # connectivity issues
@@ -48,7 +66,7 @@ def forward(self, x):
4866 run_cls = ExtendedReferenceEvaluator ,
4967 atol = 1e-5 ,
5068 rtol = 1e-5 ,
51- verbose = 1 ,
69+ verbose = 10 ,
5270 ),
5371 )
5472 self .assertEqual (len (results ), 7 )
@@ -83,7 +101,7 @@ def forward(self, x):
83101 run_cls = ExtendedReferenceEvaluator ,
84102 atol = 1e-5 ,
85103 rtol = 1e-5 ,
86- verbose = 1 ,
104+ verbose = 10 ,
87105 ),
88106 )
89107 self .assertEqual (len (results ), 6 )
@@ -115,7 +133,7 @@ def forward(self, x):
115133 run_cls = ExtendedReferenceEvaluator ,
116134 atol = 1e-5 ,
117135 rtol = 1e-5 ,
118- verbose = 1 ,
136+ verbose = 10 ,
119137 ),
120138 )
121139 self .assertEqual (len (results ), 6 )
@@ -182,7 +200,6 @@ def forward(self, x):
182200 ),
183201 )
184202 self .assertEqual (len (results ), 8 )
185- self .clean_dump ()
186203
187204 @hide_stdout ()
188205 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -285,7 +302,10 @@ def forward(self, x):
285302 ),
286303 )
287304 self .assertEqual (len (results ), 14 )
288- self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 ] * 14 )
305+ self .assertEqual (
306+ [r .err_dev for r in results ],
307+ [None , None , None , None , None , None , None , None , 0 , 0 , 0 , 0 , 0 , 0 ],
308+ )
289309
290310 @hide_stdout ()
291311 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -323,25 +343,32 @@ def forward(self, x):
323343 use_tensor = True ,
324344 ),
325345 )
326- df = pandas .DataFrame (list (map ( post_process_run_aligned_obs , results ) ))
346+ df = pandas .DataFrame (list (results ))
327347 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_custom.xlsx" ))
328348 self .assertEqual (
329349 [
330350 "ep_id_node" ,
331351 "ep_name" ,
332352 "ep_target" ,
353+ "ep_time_run" ,
333354 "err_abs" ,
334355 "err_dev" ,
356+ "err_nan" ,
335357 "err_rel" ,
336358 "onnx_id_node" ,
359+ "onnx_id_output" ,
337360 "onnx_name" ,
338361 "onnx_op_type" ,
362+ "onnx_time_run" ,
339363 "shape_type" ,
340364 ],
341365 sorted (df .columns ),
342366 )
343367 self .assertEqual (len (results ), 12 )
344- self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 ] * 12 )
368+ self .assertEqual (
369+ [r .err_dev for r in results ],
370+ [None , None , None , None , None , None , None , None , None , 0 , 0 , 0 ],
371+ )
345372 self .assertEqual (
346373 [- 1.0 , - 1.0 , - 1.0 , - 1.0 , - 10.0 , - 10.0 , - 10.0 , - 10.0 , - 1.0 , 0.0 , 1.0 , 2.0 ],
347374 df ["onnx_id_node" ].fillna (- 10 ).tolist (),
@@ -384,25 +411,32 @@ def forward(self, x):
384411 use_tensor = True ,
385412 ),
386413 )
387- df = pandas .DataFrame (list (map ( post_process_run_aligned_obs , results ) ))
414+ df = pandas .DataFrame (list (results ))
388415 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_dynamo.xlsx" ))
389416 self .assertEqual (
390417 [
391418 "ep_id_node" ,
392419 "ep_name" ,
393420 "ep_target" ,
421+ "ep_time_run" ,
394422 "err_abs" ,
395423 "err_dev" ,
424+ "err_nan" ,
396425 "err_rel" ,
397426 "onnx_id_node" ,
427+ "onnx_id_output" ,
398428 "onnx_name" ,
399429 "onnx_op_type" ,
430+ "onnx_time_run" ,
400431 "shape_type" ,
401432 ],
402433 sorted (df .columns ),
403434 )
404435 self .assertEqual (len (results ), 12 )
405- self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 ] * 12 )
436+ self .assertEqual (
437+ [r .err_dev for r in results ],
438+ [None , None , None , None , None , None , None , None , None , 0 , 0 , 0 ],
439+ )
406440 self .assertEqual (
407441 [- 1.0 , - 1.0 , - 1.0 , - 1.0 , - 10.0 , - 10.0 , - 10.0 , - 10.0 , - 1.0 , 0.0 , 1.0 , 2.0 ],
408442 df ["onnx_id_node" ].fillna (- 10 ).tolist (),
0 commit comments