Skip to content

Commit 3ef5ce7

Browse files
committed
Add numba impl for CheckAndRaise
1 parent 0933d20 commit 3ef5ce7

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Unique,
2020
UnravelIndex,
2121
)
22+
from aesara.raise_op import CheckAndRaise
2223

2324

2425
@numba_funcify.register(Bartlett)
@@ -372,3 +373,18 @@ def broadcast_to(x, *shape):
372373
return np.broadcast_to(x, scalars_shape)
373374

374375
return broadcast_to
376+
377+
378+
@numba_funcify.register(CheckAndRaise)
379+
def numba_funcify_CheckAndRaise(op, node, **kwargs):
380+
error = op.exc_type
381+
msg = op.msg
382+
383+
@numba_basic.numba_njit
384+
def check_and_raise(x, *conditions):
385+
for cond in conditions:
386+
if not cond:
387+
raise error(msg)
388+
return x
389+
390+
return check_and_raise

0 commit comments

Comments
 (0)