|
31 | 31 |
|
32 | 32 |
|
33 | 33 | @pytest.mark.parametrize( |
34 | | - "inputs, input_vals, output_fn, exc", |
| 34 | + "inputs, input_vals, output_fn", |
35 | 35 | [ |
36 | 36 | ( |
37 | 37 | [pt.vector()], |
38 | 38 | [rng.uniform(size=100).astype(config.floatX)], |
39 | 39 | lambda x: pt.gammaln(x), |
40 | | - None, |
41 | 40 | ), |
42 | 41 | ( |
43 | 42 | [pt.vector()], |
44 | 43 | [rng.standard_normal(100).astype(config.floatX)], |
45 | 44 | lambda x: pt.sigmoid(x), |
46 | | - None, |
47 | 45 | ), |
48 | 46 | ( |
49 | 47 | [pt.vector()], |
50 | 48 | [rng.standard_normal(100).astype(config.floatX)], |
51 | 49 | lambda x: pt.log1mexp(x), |
52 | | - None, |
53 | 50 | ), |
54 | 51 | ( |
55 | 52 | [pt.vector()], |
56 | 53 | [rng.standard_normal(100).astype(config.floatX)], |
57 | 54 | lambda x: pt.erf(x), |
58 | | - None, |
59 | 55 | ), |
60 | 56 | ( |
61 | 57 | [pt.vector()], |
62 | 58 | [rng.standard_normal(100).astype(config.floatX)], |
63 | 59 | lambda x: pt.erfc(x), |
64 | | - None, |
65 | 60 | ), |
66 | 61 | ( |
67 | 62 | [pt.vector()], |
68 | 63 | [rng.standard_normal(100).astype(config.floatX)], |
69 | 64 | lambda x: pt.erfcx(x), |
70 | | - None, |
71 | 65 | ), |
72 | 66 | ( |
73 | 67 | [pt.vector() for i in range(4)], |
74 | 68 | [rng.standard_normal(100).astype(config.floatX) for i in range(4)], |
75 | 69 | lambda x, y, x1, y1: (x + y) * (x1 + y1) * y, |
76 | | - None, |
77 | 70 | ), |
78 | 71 | ( |
79 | 72 | [pt.matrix(), pt.scalar()], |
80 | 73 | [rng.normal(size=(2, 2)).astype(config.floatX), 0.0], |
81 | 74 | lambda a, b: pt.switch(a, b, a), |
82 | | - None, |
83 | 75 | ), |
84 | 76 | ( |
85 | 77 | [pt.scalar(), pt.scalar()], |
|
88 | 80 | np.array(1.0, dtype=config.floatX), |
89 | 81 | ], |
90 | 82 | lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), |
91 | | - None, |
92 | 83 | ), |
93 | 84 | ( |
94 | 85 | [pt.vector(), pt.vector()], |
|
97 | 88 | rng.standard_normal(100).astype(config.floatX), |
98 | 89 | ], |
99 | 90 | lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), |
100 | | - None, |
101 | 91 | ), |
102 | 92 | ( |
103 | 93 | [pt.vector(), pt.vector()], |
|
106 | 96 | rng.standard_normal(100).astype(config.floatX), |
107 | 97 | ], |
108 | 98 | lambda x, y: scalar_my_multi_out(x, y), |
109 | | - None, |
110 | 99 | ), |
111 | 100 | ], |
| 101 | + ids=[ |
| 102 | + "gammaln", |
| 103 | + "sigmoid", |
| 104 | + "log1mexp", |
| 105 | + "erf", |
| 106 | + "erfc", |
| 107 | + "erfcx", |
| 108 | + "complex_arithmetic", |
| 109 | + "switch", |
| 110 | + "add_inplace_scalar", |
| 111 | + "add_inplace_vector", |
| 112 | + "scalar_multi_out", |
| 113 | + ], |
112 | 114 | ) |
113 | | -def test_Elemwise(inputs, input_vals, output_fn, exc): |
| 115 | +def test_Elemwise(inputs, input_vals, output_fn): |
114 | 116 | outputs = output_fn(*inputs) |
115 | 117 |
|
116 | | - cm = contextlib.suppress() if exc is None else pytest.raises(exc) |
117 | | - with cm: |
118 | | - compare_numba_and_py( |
119 | | - inputs, |
120 | | - outputs, |
121 | | - input_vals, |
122 | | - ) |
| 118 | + compare_numba_and_py( |
| 119 | + inputs, |
| 120 | + outputs, |
| 121 | + input_vals, |
| 122 | + ) |
123 | 123 |
|
124 | 124 |
|
125 | 125 | @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults") |
|
0 commit comments