@@ -57,6 +57,7 @@ class ResultExecution:
5757 summary : str
5858 op_type : str
5959 name : str
60+ value : Optional [Any ] = None
6061
6162 def __len__ (self ) -> int :
6263 return 6
@@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
122123 else :
123124 value2 = value .flatten ().astype (np .float64 )
124125 value4 = value2 .reshape ((4 , - 1 )).sum (axis = 1 )
125- value4i = value4 .astype (np .int64 ) % modulo
126- s = "" .join ([chr (65 + i ) for i in value4i ])
127- return s
126+ value4 = np .where (np .abs (value4 ) < 1e10 , value4 , np .nan )
127+ s = []
128+ for v in value4 :
129+ s .append ("?" if np .isnan (v ) else (chr (65 + int (v ) % modulo )))
130+ return "" .join (s )
128131
129132
130133class YieldEvaluator :
@@ -228,6 +231,7 @@ def enumerate_summarized(
228231 output_names : Optional [List [str ]] = None ,
229232 feed_inputs : Optional [Dict [str , Any ]] = None ,
230233 raise_exc : bool = True ,
234+ keep_tensor : bool = False ,
231235 ) -> Iterator [ResultExecution ]:
232236 """
233237 Executes the onnx model and enumerate intermediate results without their names.
@@ -236,17 +240,40 @@ def enumerate_summarized(
236240 :param feed_inputs: dictionary `{ input name: input value }`
237241 :param raise_exc: raises an exception if the execution fails or stop
238242 where it is
243+ :param keep_tensor:keep the tensor in order to compute precise distances
239244 :return: iterator on ResultExecution
240245 """
241246 for kind , name , value , op_type in self .enumerate_results (
242247 output_names , feed_inputs , raise_exc = raise_exc
243248 ):
244249 summary = make_summary (value )
245250 yield ResultExecution (
246- kind , value .dtype , value .shape , summary , op_type , name
251+ kind ,
252+ value .dtype ,
253+ value .shape ,
254+ summary ,
255+ op_type ,
256+ name ,
257+ value = value if keep_tensor else None ,
247258 )
248259
249260
261+ def discrepancies (
262+ expected : np .ndarray , value : np .ndarray , eps : float = 1e-7
263+ ) -> Dict [str , float ]:
264+ """
265+ Computes absolute error and relative error between two matrices.
266+ """
267+ assert (
268+ expected .size == value .size
269+ ), f"Incompatible shapes v1.shape={ expected .shape } , v2.shape={ value .shape } "
270+ expected = expected .ravel ().astype (np .float32 )
271+ value = value .ravel ().astype (np .float32 )
272+ diff = np .abs (expected - value )
273+ rel = diff / (np .abs (expected ) + eps )
274+ return dict (aerr = float (diff .max ()), rerr = float (rel .max ()))
275+
276+
250277class DistanceExecution :
251278 """
252279 Computes a distance between two results.
@@ -403,6 +430,14 @@ def to_str(
403430 d = self .distance_pair (d1 , d2 )
404431 symbol = "=" if d == 0 else "~"
405432 line = f"{ symbol } | { _align (str (d1 ), column_size )} | { _align (str (d2 ), column_size )} "
433+ if (
434+ d1 .value is not None
435+ and d2 .value is not None
436+ and d1 .value .size == d2 .value .size
437+ ):
438+ disc = discrepancies (d1 .value , d2 .value )
439+ a , r = disc ["aerr" ], disc ["rerr" ]
440+ line += f" | a={ a :.3f} r={ r :.3f} "
406441 elif i == last [0 ]:
407442 d2 = s2 [j ]
408443 line = (
@@ -551,6 +586,7 @@ def compare_onnx_execution(
551586 verbose : int = 0 ,
552587 raise_exc : bool = True ,
553588 mode : str = "execute" ,
589+ keep_tensor : bool = False ,
554590) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
555591 """
556592 Compares the execution of two onnx models.
@@ -566,6 +602,7 @@ def compare_onnx_execution(
566602 :param raise_exc: raise exception if the execution fails or stop at the error
567603 :param mode: the model should be executed but the function can be executed
568604 but the comparison may append on nodes only
605+ :param keep_tensor: keeps the tensor in order to compute a precise distance
569606 :return: four results, a sequence of results for the first model and the second model,
570607 the alignment between the two, DistanceExecution
571608 """
@@ -589,15 +626,15 @@ def compare_onnx_execution(
589626 print ("[compare_onnx_execution] execute first model" )
590627 res1 = list (
591628 YieldEvaluator (model1 ).enumerate_summarized (
592- None , feeds1 , raise_exc = raise_exc
629+ None , feeds1 , raise_exc = raise_exc , keep_tensor = keep_tensor
593630 )
594631 )
595632 if verbose :
596633 print (f"[compare_onnx_execution] got { len (res1 )} results" )
597634 print ("[compare_onnx_execution] execute second model" )
598635 res2 = list (
599636 YieldEvaluator (model2 ).enumerate_summarized (
600- None , feeds2 , raise_exc = raise_exc
637+ None , feeds2 , raise_exc = raise_exc , keep_tensor = keep_tensor
601638 )
602639 )
603640 elif mode == "nodes" :
0 commit comments