Skip to content

Commit 1b92c77

Browse files
committed
Remove stale Assert tests
These tests were covering things that don't exist anymore. params in python perform method of Ops, or misbehavior of an Op not respecting the signature
1 parent c722e9f commit 1b92c77

File tree

3 files changed

+19
-50
lines changed

3 files changed

+19
-50
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -233,37 +233,23 @@ def generate_fallback_impl(op, node, storage_map=None, **kwargs):
233233
node.dprint(depth=5, print_type=True)
234234

235235
n_outputs = len(node.outputs)
236+
single_out = n_outputs == 1
236237

237-
if n_outputs > 1:
238-
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
239-
else:
238+
if single_out:
240239
ret_sig = get_numba_type(node.outputs[0].type)
241-
242-
output_types = tuple(out.type for out in node.outputs)
243-
244-
def py_perform(inputs):
245-
outputs = [[None] for i in range(n_outputs)]
246-
op.perform(node, inputs, outputs)
247-
return outputs
248-
249-
if n_outputs == 1:
250-
251-
def py_perform_return(inputs):
252-
return output_types[0].filter(py_perform(inputs)[0][0])
253-
254240
else:
241+
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
255242

256-
def py_perform_return(inputs):
257-
# zip strict not specified because we are in a hot loop
258-
return tuple(
259-
out_type.filter(out[0])
260-
for out_type, out in zip(output_types, py_perform(inputs))
261-
)
243+
def py_perform(inputs):
244+
output_storage = [[None] for _i in range(n_outputs)]
245+
op.perform(node, inputs, output_storage)
246+
outputs = tuple(o[0] for o in output_storage)
247+
return outputs[0] if single_out else outputs
262248

263249
@numba_njit
264250
def perform(*inputs):
265251
with numba.objmode(ret=ret_sig):
266-
ret = py_perform_return(inputs)
252+
ret = py_perform(inputs)
267253
return ret
268254

269255
return perform

tests/link/numba/test_basic.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from pytensor.ifelse import ifelse
2727
from pytensor.link.numba.dispatch import basic as numba_basic
2828
from pytensor.link.numba.linker import NumbaLinker
29-
from pytensor.raise_op import assert_op
3029
from pytensor.scalar.basic import ScalarOp, as_scalar
3130
from pytensor.tensor.elemwise import Elemwise
3231

@@ -372,32 +371,6 @@ def test_perform(inputs, op, exc):
372371
)
373372

374373

375-
def test_perform_params():
376-
"""This tests for `Op.perform` implementations that require the `params` arguments."""
377-
378-
x = pt.vector(shape=(2,))
379-
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
380-
381-
out = assert_op(x, np.array(True))
382-
383-
compare_numba_and_py([x], out, [x_test_value])
384-
385-
386-
def test_perform_type_convert():
387-
"""This tests the use of `Type.filter` in `objmode`.
388-
389-
The `Op.perform` takes a single input that it returns as-is, but it gets a
390-
native scalar and it's supposed to return an `np.ndarray`.
391-
"""
392-
393-
x = pt.vector()
394-
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
395-
396-
out = assert_op(x.sum(), np.array(True))
397-
398-
compare_numba_and_py([x], out, [x_test_value])
399-
400-
401374
def test_shared():
402375
a = shared(np.array([1, 2, 3], dtype=config.floatX))
403376

tests/link/numba/test_extra_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytensor.tensor as pt
77
from pytensor import config
8+
from pytensor.raise_op import assert_op
89
from pytensor.tensor import extra_ops
910
from tests.link.numba.test_basic import compare_numba_and_py
1011

@@ -372,3 +373,12 @@ def test_Searchsorted(a, v, side, sorter, exc):
372373
g,
373374
[test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
374375
)
376+
377+
378+
def test_check_and_raise():
379+
x = pt.vector()
380+
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
381+
382+
out = assert_op(x.sum(), np.array(True))
383+
384+
compare_numba_and_py([x], out, [x_test_value])

0 commit comments

Comments
 (0)