|
1 | 1 | import unittest |
| 2 | +from typing import Optional |
2 | 3 | import numpy as np |
3 | 4 | import onnx |
4 | 5 | import onnx.helper as oh |
5 | 6 | import torch |
6 | 7 | import onnxruntime |
7 | | -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout |
| 8 | +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings |
8 | 9 | from onnx_diagnostic.helpers.onnx_helper import from_array_extended |
9 | 10 | from onnx_diagnostic.reference import ( |
10 | 11 | OnnxruntimeEvaluator, |
|
22 | 23 |
|
23 | 24 |
|
24 | 25 | class TestOnnxruntimeEvaluator(ExtTestCase): |
| 26 | + def _range(self, *shape, bias: Optional[float] = None): |
| 27 | + n = np.prod(shape) |
| 28 | + x = np.arange(n).astype(np.float32) / n |
| 29 | + if bias: |
| 30 | + x = x + bias |
| 31 | + return x.reshape(tuple(shape)).astype(np.float32) |
| 32 | + |
| 33 | + @ignore_warnings(FutureWarning) |
25 | 34 | def test_ort_eval_scan_cdist_add(self): |
26 | 35 |
|
27 | 36 | def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor): |
@@ -69,6 +78,7 @@ def forward(self, x): |
69 | 78 | got = orte.run(None, {name: x.numpy()})[0] |
70 | 79 | self.assertEqualArray(expected, got) |
71 | 80 |
|
| 81 | + @ignore_warnings((UserWarning, FutureWarning)) |
72 | 82 | def test_ort_eval_cond(self): |
73 | 83 | import torch |
74 | 84 |
|
@@ -180,6 +190,7 @@ def test_constant_bool_input(self): |
180 | 190 | self.assertEqual(got.dtype, torch.bool) |
181 | 191 | self.assertEqual(got[0], True) |
182 | 192 |
|
| 193 | + @hide_stdout() |
183 | 194 | def test_ort_eval_loop(self): |
184 | 195 | model = torch.nn.EmbeddingBag(num_embeddings=49157, embedding_dim=32, mode="sum") |
185 | 196 | a = torch.tensor([[39906, 39906]]).long() |
@@ -226,6 +237,28 @@ def test_report_results_comparison_ort(self): |
226 | 237 | self.assertLess(d[(0, "nx"), "r_cos"], 1e-6) |
227 | 238 | self.assertLess(d[(2, "u"), "r_exp"], 1e-6) |
228 | 239 |
|
| 240 | + @hide_stdout() |
| 241 | + def test_skip_layer_normalization(self): |
| 242 | + node = oh.make_node( |
| 243 | + "SkipLayerNormalization", |
| 244 | + ["x", "skip", "beta", "gamma", "bias"], |
| 245 | + ["Z"], |
| 246 | + epsilon=1.0e-5, |
| 247 | + domain="com.microsoft", |
| 248 | + ) |
| 249 | + feeds = dict( |
| 250 | + x=self._range(2, 3, 8), |
| 251 | + skip=self._range(2, 3, 8, bias=3), |
| 252 | + beta=self._range(8, bias=1), |
| 253 | + gamma=self._range(8, bias=2), |
| 254 | + bias=self._range(8, bias=0.1), |
| 255 | + ) |
| 256 | + ref = ExtendedReferenceEvaluator(node) |
| 257 | + expected = ref.run(None, feeds) |
| 258 | + rt = OnnxruntimeEvaluator(node, verbose=10, opsets={"": 22}) |
| 259 | + got = rt.run(None, feeds) |
| 260 | + self.assertEqualAny(expected, got, atol=1e-4) |
| 261 | + |
229 | 262 |
|
230 | 263 | if __name__ == "__main__": |
231 | 264 | unittest.main(verbosity=2) |
0 commit comments