11from dataclasses import dataclass
2- from typing import Any , Dict , List , Iterator , Optional , Tuple
2+ from typing import Any , Dict , List , Iterator , Optional , Tuple , Union
33from enum import IntEnum
44import numpy as np
55from onnx import ModelProto , TensorProto , ValueInfoProto
@@ -77,6 +77,12 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
7777 :param module: discretization parameter
7878 :return: short string
7979 """
80+ if isinstance (value , np .float32 ):
81+ # This should not happen.
82+ value = np .array (value )
83+ assert isinstance (
84+ value , np .ndarray
85+ ), f"Unexpected type { type (value )} for value, it must be a numpy array."
8086 value4 = np .zeros (length , dtype = np .float64 )
8187 if value .size <= length :
8288 value4 [: value .size ] = value .flatten ().astype (np .float64 )
@@ -170,6 +176,9 @@ def enumerate_results(
170176 outputs = node .run (* inputs , ** linked_attributes )
171177 except Exception :
172178 if raise_exc :
179+ # ExtendedReferenceEvaluator(self.onnx_model, verbose=10).run(
180+ # None, feed_inputs
181+ # )
173182 raise
174183 yield_output = False
175184 break
@@ -286,12 +295,12 @@ def distance_sequence(
286295 :param s2: second sequence
287296 :return: distance and alignment
288297 """
289- delay = self .max_lag
298+ delay = max ( self .max_lag , abs ( len ( s2 ) - len ( s1 )) + 1 )
290299 distance = {(- 1 , - 1 ): 0 }
291300 predecessor = {(- 1 , - 1 ): None }
292301 for i in range (len (s1 )):
293302 for j in range (max (0 , i - delay ), min (len (s2 ), i + delay )):
294- best = 1e100
303+ best = distance . get (( i , j ), 1e100 )
295304 pred = None
296305 ki , kj = i - 1 , j - 1
297306 if (ki , kj ) in distance :
@@ -418,7 +427,7 @@ def generate_inputs(model: ModelProto) -> List[np.ndarray]:
418427def compare_onnx_execution (
419428 model1 : ModelProto ,
420429 model2 : ModelProto ,
421- inputs : Optional [List [Any ]] = None ,
430+ inputs : Optional [Union [ List [Any ], Tuple [ Dict [ str , Any ]] ]] = None ,
422431 verbose : int = 0 ,
423432 raise_exc : bool = True ,
424433) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
@@ -430,7 +439,8 @@ def compare_onnx_execution(
430439
431440 :param model1: first model
432441 :param model2: second model
433- :param inputs: inputs to use
442+ :param inputs: inputs to use, a list of inputs if both models have
443+ the same number of inputs or two dictionaries, one for each model
434444 :param verbose: verbosity
435445 :param raise_exc: raise exception if the execution fails or stop at the error
436446 :return: four results, a sequence of results for the first model and the second model,
@@ -440,8 +450,14 @@ def compare_onnx_execution(
440450 print ("[compare_onnx_execution] generate inputs" )
441451 if inputs is None :
442452 inputs = generate_inputs (model1 )
443- feeds1 = {i .name : v for i , v in zip (model1 .graph .input , inputs )}
444- feeds2 = {i .name : v for i , v in zip (model2 .graph .input , inputs )}
453+ if isinstance (inputs , tuple ):
454+ assert len (inputs ) == 2 , f"Unexpected number { len (inputs )} of inputs."
455+ feeds1 , feeds2 = inputs
456+ else :
457+ feeds1 = {i .name : v for i , v in zip (model1 .graph .input , inputs )}
458+ feeds2 = {i .name : v for i , v in zip (model2 .graph .input , inputs )}
459+ assert isinstance (feeds1 , dict ), f"Unexpected type { type (feeds1 )} for inputs"
460+ assert isinstance (feeds2 , dict ), f"Unexpected type { type (feeds2 )} for inputs"
445461 if verbose :
446462 print (f"[compare_onnx_execution] got { len (inputs )} inputs" )
447463 print ("[compare_onnx_execution] execute first model" )
0 commit comments