Skip to content

Commit e8dcd00

Browse files
authored
Fix empty outputs for OnnxruntimeEvaluator (#305)
* Fix empty outputs for OnnxruntimeEvaluator * fix
1 parent 5d397f5 commit e8dcd00

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,31 @@ def test_skip_layer_normalization(self):
259259
got = rt.run(None, feeds)
260260
self.assertEqualAny(expected, got, atol=1e-4)
261261

262+
@hide_stdout()
263+
def test_skip_simplified_layer_normalization(self):
264+
node = oh.make_node(
265+
"SkipSimplifiedLayerNormalization",
266+
["x", "skip", "beta", "gamma"],
267+
["Z", "", "", "bias"],
268+
epsilon=1.0e-5,
269+
domain="com.microsoft",
270+
)
271+
feeds = dict(
272+
x=self._range(2, 3, 8),
273+
skip=self._range(2, 3, 8, bias=3),
274+
beta=self._range(8, bias=1),
275+
gamma=self._range(8, bias=2),
276+
)
277+
rt = OnnxruntimeEvaluator(node, verbose=10, opsets={"": 22})
278+
got = rt.run(None, feeds)
279+
self.assertEqual(len(got), 2)
280+
self.assertIsInstance(got[0], np.ndarray)
281+
self.assertIsInstance(got[1], np.ndarray)
282+
self.assertEqual(got[0].shape, feeds["x"].shape)
283+
self.assertEqual(got[0].dtype, feeds["x"].dtype)
284+
self.assertEqual(got[1].shape, feeds["x"].shape)
285+
self.assertEqual(got[1].dtype, feeds["x"].dtype)
286+
262287

263288
if __name__ == "__main__":
264289
unittest.main(verbosity=2)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,11 @@ def run(
278278
outputs = self._run_local(node, inputs, results)
279279
else:
280280
outputs = self._run(node, inputs, results)
281-
for name, value in zip(node.output, outputs):
282-
if name == "":
283-
continue
281+
node_output = [o for o in node.output if o]
282+
assert len(node_output) == len(
283+
outputs
284+
), f"Length mismatch between node output={node.output} and outputs={outputs}"
285+
for name, value in zip(node_output, outputs):
284286
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
285287
assert isinstance(name, str), f"unexpected type for name {type(name)}"
286288
results[name] = value
@@ -384,6 +386,11 @@ def _make_model_proto(
384386
onx = shi.infer_shapes(onx)
385387
return onx
386388

389+
def _make_model_outputs(
390+
self, node: NodeProto, inputs: List[ValueInfoProto]
391+
) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
392+
return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
393+
387394
@classmethod
388395
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
389396
"""
@@ -434,6 +441,7 @@ def _get_sess(
434441
node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
435442
)
436443
]
444+
prenodes = [] # type: ignore[var-annotated]
437445
else:
438446
unique_names = set()
439447
vinputs = []
@@ -447,9 +455,9 @@ def _get_sess(
447455
vinputs.append(value)
448456

449457
# no need to run shape inference
450-
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
458+
prenodes, voutputs = self._make_model_outputs(node, vinputs)
451459

452-
onx = self._make_model_proto([node], vinputs, voutputs)
460+
onx = self._make_model_proto([*prenodes, node], vinputs, voutputs)
453461
if node.op_type in {"Shape", "Size"}:
454462
on_cpu = True
455463

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,7 @@ def _loop_cmp(
492492
f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
493493
)
494494
else:
495-
print(
496-
f"[run_align-dx] discrepancies "
497-
f"{string_diff(d, with_shape=True, with_device=True)} - "
498-
f"[{to}/{o}]"
499-
)
495+
print(f"[run_align-dx] discrepancies {string_diff(d)} - [{to}/{o}]")
500496
return (i, i_onnx, o, to, string_type(torch_results[to], **str_kws), d)
501497
return None
502498

0 commit comments

Comments
 (0)