@@ -118,6 +118,7 @@ def enumerate_results(
118118 self ,
119119 output_names : Optional [List [str ]] = None ,
120120 feed_inputs : Optional [Dict [str , Any ]] = None ,
121+ raise_exc : bool = True ,
121122 ) -> Iterator [Tuple [ResultType , str , Any ]]:
122123 """
123124 Executes the onnx model and enumerate all the intermediate results.
@@ -148,6 +149,7 @@ def enumerate_results(
148149 yield ResultType .INPUT , k , v , None
149150
150151 # step 2: execute nodes
152+ yield_output = True
151153 for node in self .evaluator .rt_nodes_ :
152154 for i in node .input :
153155 if i not in results :
@@ -160,39 +162,48 @@ def enumerate_results(
160162 linked_attributes = {}
161163 if node .has_linked_attribute and attributes :
162164 linked_attributes ["linked_attributes" ] = attributes
163- if node .need_context ():
164- outputs = node .run (* inputs , context = results , ** linked_attributes )
165- else :
166- outputs = node .run (* inputs , ** linked_attributes )
165+
166+ try :
167+ if node .need_context ():
168+ outputs = node .run (* inputs , context = results , ** linked_attributes )
169+ else :
170+ outputs = node .run (* inputs , ** linked_attributes )
171+ except Exception :
172+ if raise_exc :
173+ raise
174+ yield_output = False
175+ break
176+
167177 for name , value in zip (node .output , outputs ):
168178 yield ResultType .RESULT , name , value , node .op_type
169179 results [name ] = value
170180
171181 # step 3: outputs
172- for name in output_names :
173- if name not in results :
174- raise RuntimeError (
175- f"Unable to find output name { name !r} in { sorted (results )} , proto is\n { self .proto_ } "
176- )
177- yield ResultType .OUTPUT , name , results [name ], None
182+ if yield_output :
183+ for name in output_names :
184+ if name not in results :
185+ raise RuntimeError (
186+ f"Unable to find output name { name !r} in { sorted (results )} , proto is\n { self .proto_ } "
187+ )
188+ yield ResultType .OUTPUT , name , results [name ], None
178189
179190 def enumerate_summarized (
180191 self ,
181192 output_names : Optional [List [str ]] = None ,
182193 feed_inputs : Optional [Dict [str , Any ]] = None ,
194+ raise_exc : bool = True ,
183195 ) -> Iterator [ResultExecution ]:
184196 """
185197 Executes the onnx model and enumerate intermediate results without their names.
186198
187- Args:
188- output_names: requested outputs by names, None for all
189- feed_inputs: dictionary `{ input name: input value }`
190-
191- Returns:
192- iterator on tuple(result kind, node.type, dtype, shape, value, result name)
199+ :param output_names: requested outputs by names, None for all
200+ :param feed_inputs: dictionary `{ input name: input value }`
201+ :param raise_exc: raises an exception if the execution fails or stop
202+ where it is
203+ :return: iterator on ResultExecution
193204 """
194205 for kind , name , value , op_type in self .enumerate_results (
195- output_names , feed_inputs
206+ output_names , feed_inputs , raise_exc = raise_exc
196207 ):
197208 summary = make_summary (value )
198209 yield ResultExecution (
@@ -328,6 +339,7 @@ def to_str(
328339 """
329340 rows = []
330341 last = - 1 , - 1
342+ row_index = 1
331343 for i , j in alignment :
332344 assert i < len (s1 ), f"Unexpected value i={ i } >= len(s1)={ len (s1 )} "
333345 assert j < len (s2 ), f"Unexpected value i={ j } >= len(s2)={ len (s2 )} "
@@ -338,20 +350,18 @@ def to_str(
338350 d2 = s2 [j ]
339351 d = self .distance_pair (d1 , d2 )
340352 symbol = "=" if d == 0 else "~"
341- rows .append (
342- f"{ symbol } | { _align (str (d1 ), column_size )} | { _align (str (d2 ), column_size )} "
343- )
353+ line = f"{ symbol } | { _align (str (d1 ), column_size )} | { _align (str (d2 ), column_size )} "
344354 elif i == last [0 ]:
345355 d2 = s2 [j ]
346- rows . append (
356+ line = (
347357 f"+ | { _align ('' , column_size )} | { _align (str (d2 ), column_size )} "
348358 )
349359 else :
350360 d1 = s1 [i ]
351- rows .append (
352- f"- | { _align (str (d1 ), column_size )} | { _align ('' , column_size )} "
353- )
361+ line = f"- | { _align (str (d1 ), column_size )} | { _align ('' , column_size )} "
362+ rows .append (f"{ row_index : 3d} { line } " )
354363 last = i , j
364+ row_index += 1
355365 return "\n " .join (rows )
356366
357367
@@ -410,6 +420,7 @@ def compare_onnx_execution(
410420 model2 : ModelProto ,
411421 inputs : Optional [List [Any ]] = None ,
412422 verbose : int = 0 ,
423+ raise_exc : bool = True ,
413424) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
414425 """
415426 Compares the execution of two onnx models.
@@ -421,6 +432,7 @@ def compare_onnx_execution(
421432 :param model2: second model
422433 :param inputs: inputs to use
423434 :param verbose: verbosity
435+ :param raise_exc: raise exception if the execution fails or stop at the error
424436 :return: four results, a sequence of results for the first model and the second model,
425437 the alignment between the two, DistanceExecution
426438 """
@@ -433,11 +445,15 @@ def compare_onnx_execution(
433445 if verbose :
434446 print (f"[compare_onnx_execution] got { len (inputs )} inputs" )
435447 print ("[compare_onnx_execution] execute first model" )
436- res1 = list (YieldEvaluator (model1 ).enumerate_summarized (None , feeds1 ))
448+ res1 = list (
449+ YieldEvaluator (model1 ).enumerate_summarized (None , feeds1 , raise_exc = raise_exc )
450+ )
437451 if verbose :
438452 print (f"[compare_onnx_execution] got { len (res1 )} results" )
439453 print ("[compare_onnx_execution] execute second model" )
440- res2 = list (YieldEvaluator (model2 ).enumerate_summarized (None , feeds2 ))
454+ res2 = list (
455+ YieldEvaluator (model2 ).enumerate_summarized (None , feeds2 , raise_exc = raise_exc )
456+ )
441457 if verbose :
442458 print (f"[compare_onnx_execution] got { len (res2 )} results" )
443459 print ("[compare_onnx_execution] compute edit distance" )
0 commit comments