From 6c437baa5fffb8381b90c7f01c4c6aacd66bd53c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:02:51 +0200 Subject: [PATCH 01/16] Remove unused numba_vectorize --- pytensor/link/numba/dispatch/basic.py | 7 ------- tests/link/numba/test_basic.py | 29 --------------------------- 2 files changed, 36 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 87b8e380d3..8a5d628954 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -87,13 +87,6 @@ def numba_njit(*args, fastmath=None, **kwargs): return numba.njit(*args, fastmath=fastmath, **kwargs) -def numba_vectorize(*args, **kwargs): - if len(args) > 0 and callable(args[0]): - return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0]) - - return numba.vectorize(*args, cache=config.numba__cache, **kwargs) - - def get_numba_type( pytensor_type: Type, layout: str = "A", diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index fd9a48111f..8e8f3cea31 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -1,5 +1,4 @@ import contextlib -import inspect from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any from unittest import mock @@ -151,30 +150,6 @@ def njit_noop(*args, **kwargs): else: return lambda x: x - def vectorize_noop(*args, **kwargs): - def wrap(fn): - # `numba.vectorize` allows an `out` positional argument. We need - # to account for that - sig = inspect.signature(fn) - nparams = len(sig.parameters) - - def inner_vec(*args): - if len(args) > nparams: - # An `out` argument has been specified for an in-place - # operation - out = args[-1] - out[...] = np.vectorize(fn)(*args[:nparams]) - return out - else: - return np.vectorize(fn)(*args) - - return inner_vec - - if len(args) == 1 and callable(args[0]): - return wrap(args[0], **kwargs) - else: - return wrap - def py_global_numba_func(func): if hasattr(func, "py_func"): return func.py_func @@ -182,7 +157,6 @@ def py_global_numba_func(func): mocks = [ mock.patch("numba.njit", njit_noop), - mock.patch("numba.vectorize", vectorize_noop), mock.patch( "pytensor.link.numba.dispatch.basic.global_numba_func", py_global_numba_func, @@ -191,9 +165,6 @@ def py_global_numba_func(func): "pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem ), mock.patch("pytensor.link.numba.dispatch.basic.numba_njit", njit_noop), - mock.patch( - "pytensor.link.numba.dispatch.basic.numba_vectorize", vectorize_noop - ), mock.patch( "pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x ), From 3b359fc212594938fc67b09b77b4667a6c0a352e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 14 Oct 2025 13:39:36 +0200 Subject: [PATCH 02/16] Remove numba__vectorize_target config --- pytensor/configdefaults.py | 6 ------ pytensor/configparser.py | 1 - tests/link/numba/test_basic.py | 10 ---------- 3 files changed, 17 deletions(-) diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index f99b8240ca..6763509d75 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -1099,12 +1099,6 @@ def add_scan_configvars(): def add_numba_configvars(): - config.add( - "numba__vectorize_target", - ("Default target for numba.vectorize."), - EnumStr("cpu", ["parallel", "cuda"], mutable=True), - in_c_key=False, - ) config.add( "numba__fastmath", ("If True, use Numba's fastmath mode."), diff --git a/pytensor/configparser.py b/pytensor/configparser.py index 
c7da71426d..d33c970ba1 100644 --- a/pytensor/configparser.py +++ b/pytensor/configparser.py @@ -157,7 +157,6 @@ class PyTensorConfigParser: scan__allow_gc: bool scan__allow_output_prealloc: bool # add_numba_configvars - numba__vectorize_target: str numba__fastmath: bool numba__cache: bool # add_caching_dir_configvars diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 8e8f3cea31..8ee035966b 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -769,16 +769,6 @@ def test_IfElse(inputs, cond_fn, true_vals, false_vals): compare_numba_and_py(inputs, out, test_values) -@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409") -def test_config_options_parallel(): - x = pt.dvector() - - with config.change_flags(numba__vectorize_target="parallel"): - pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] - assert numba_mul_fn.targetoptions["parallel"] is True - - def test_config_options_fastmath(): x = pt.dvector() From 6660ad70f3974b1b70d6d173dd990adb34644278 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:41:56 +0200 Subject: [PATCH 03/16] Remove unused global_numba_func --- pytensor/link/numba/dispatch/basic.py | 8 -------- tests/link/numba/test_basic.py | 9 --------- 2 files changed, 17 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8a5d628954..e63cf13141 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -42,14 +42,6 @@ from pytensor.tensor.type_other import MakeSlice, NoneConst -def global_numba_func(func): - """Use to return global numba functions in numba_funcify_*. - - This allows tests to remove the compilation using mock. 
- """ - return func - - def numba_njit(*args, fastmath=None, **kwargs): kwargs.setdefault("cache", config.numba__cache) kwargs.setdefault("no_cpython_wrapper", True) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 8ee035966b..fb92d1cf0c 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -150,17 +150,8 @@ def njit_noop(*args, **kwargs): else: return lambda x: x - def py_global_numba_func(func): - if hasattr(func, "py_func"): - return func.py_func - return func - mocks = [ mock.patch("numba.njit", njit_noop), - mock.patch( - "pytensor.link.numba.dispatch.basic.global_numba_func", - py_global_numba_func, - ), mock.patch( "pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem ), From 182abd0e99017e18a7429051fd37a3987b45a0c8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 16:58:02 +0200 Subject: [PATCH 04/16] Remove duplicated Solve dispatch --- pytensor/link/numba/dispatch/basic.py | 48 --------------------------- 1 file changed, 48 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index e63cf13141..2e745224c2 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -8,8 +8,6 @@ import numba import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -import scipy -import scipy.special from llvmlite import ir from numba import types from numba.core.errors import NumbaWarning, TypingError @@ -36,7 +34,6 @@ from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.slinalg import Solve from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst @@ -626,51 +623,6 @@ def dot_with_cast(x, y): return dot_with_cast -@numba_funcify.register(Solve) -def numba_funcify_Solve(op, node, **kwargs): - assume_a = op.assume_a - # check_finite = op.check_finite - - if assume_a != "gen": - lower = op.lower - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`compute_uv` argument to `numpy.linalg.svd`." 
- ), - UserWarning, - ) - - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_njit - def solve(a, b): - with numba.objmode(ret=ret_sig): - ret = scipy.linalg.solve_triangular( - a, - b, - lower=lower, - # check_finite=check_finite - ) - return ret - - else: - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_njit - def solve(a, b): - return np.linalg.solve( - inputs_cast(a), - inputs_cast(b), - # assume_a=assume_a, - # check_finite=check_finite, - ).astype(out_dtype) - - return solve - - @numba_funcify.register(BatchedDot) def numba_funcify_BatchedDot(op, node, **kwargs): dtype = node.outputs[0].type.numpy_dtype From 0609870eebebf3644f9d46d4e22e015281f36316 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:05:25 +0200 Subject: [PATCH 05/16] Move slice dispatcher functionality to subtensor.py --- pytensor/link/numba/dispatch/basic.py | 79 +--------------------- pytensor/link/numba/dispatch/subtensor.py | 81 ++++++++++++++++++++++- 2 files changed, 82 insertions(+), 78 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2e745224c2..8cb556c283 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,5 +1,3 @@ -import operator -import sys import warnings from copy import copy from functools import singledispatch @@ -8,11 +6,10 @@ import numba import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -from llvmlite import ir from numba import types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload +from numba.extending import overload from pytensor import In, config from pytensor.compile import NUMBA @@ -36,7 +33,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import MakeSlice, NoneConst +from pytensor.tensor.type_other import NoneConst def numba_njit(*args, fastmath=None, **kwargs): @@ -149,69 +146,6 @@ def create_numba_signature( return numba.types.void(*input_types) -def slice_new(self, start, stop, step): - fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) - fn = self._get_function(fnty, name="PySlice_New") - return self.builder.call(fn, [start, stop, step]) - - -def enable_slice_boxing(): - """Enable boxing for Numba's native ``slice``s. - - TODO: this can be removed when https://github.com/numba/numba/pull/6939 is - merged and a release is made. - """ - - @box(types.SliceType) - def box_slice(typ, val, c): - """Implement boxing for ``slice`` objects in Numba. - - This makes it possible to return an Numba's internal representation of a - ``slice`` object as a proper ``slice`` to Python. 
- """ - start = c.builder.extract_value(val, 0) - stop = c.builder.extract_value(val, 1) - - none_val = ir.Constant(ir.IntType(64), sys.maxsize) - - start_is_none = c.builder.icmp_signed("==", start, none_val) - start = c.builder.select( - start_is_none, - c.pyapi.get_null_object(), - c.box(types.int64, start), - ) - - stop_is_none = c.builder.icmp_signed("==", stop, none_val) - stop = c.builder.select( - stop_is_none, - c.pyapi.get_null_object(), - c.box(types.int64, stop), - ) - - if typ.has_step: - step = c.builder.extract_value(val, 2) - step_is_none = c.builder.icmp_signed("==", step, none_val) - step = c.builder.select( - step_is_none, - c.pyapi.get_null_object(), - c.box(types.int64, step), - ) - else: - step = c.pyapi.get_null_object() - - slice_val = slice_new(c.pyapi, start, stop, step) - - return slice_val - - @numba.extending.overload(operator.contains) - def in_seq_empty_tuple(x, y): - if isinstance(x, types.Tuple) and not x.types: - return lambda x, y: False - - -enable_slice_boxing() - - def to_scalar(x): return np.asarray(x).item() @@ -388,15 +322,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): return deepcopyop -@numba_funcify.register(MakeSlice) -def numba_funcify_MakeSlice(op, **kwargs): - @numba_njit - def makeslice(*x): - return slice(*x) - - return makeslice - - @numba_funcify.register(Shape) def numba_funcify_Shape(op, **kwargs): @numba_njit diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index e877241977..4727eaf337 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,4 +1,11 @@ +import operator +import sys + +import numba import numpy as np +from llvmlite import ir +from numba import types +from numba.core.pythonapi import box from pytensor.graph import Type from pytensor.link.numba.dispatch import numba_funcify @@ -14,7 +21,79 @@ IncSubtensor, Subtensor, ) -from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType + + +def slice_new(self, start, stop, step): + fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) + fn = self._get_function(fnty, name="PySlice_New") + return self.builder.call(fn, [start, stop, step]) + + +def enable_slice_boxing(): + """Enable boxing for Numba's native ``slice``s. + + TODO: this can be removed when https://github.com/numba/numba/pull/6939 is + merged and a release is made. + """ + + @box(types.SliceType) + def box_slice(typ, val, c): + """Implement boxing for ``slice`` objects in Numba. + + This makes it possible to return an Numba's internal representation of a + ``slice`` object as a proper ``slice`` to Python. 
+ """ + start = c.builder.extract_value(val, 0) + stop = c.builder.extract_value(val, 1) + + none_val = ir.Constant(ir.IntType(64), sys.maxsize) + + start_is_none = c.builder.icmp_signed("==", start, none_val) + start = c.builder.select( + start_is_none, + c.pyapi.get_null_object(), + c.box(types.int64, start), + ) + + stop_is_none = c.builder.icmp_signed("==", stop, none_val) + stop = c.builder.select( + stop_is_none, + c.pyapi.get_null_object(), + c.box(types.int64, stop), + ) + + if typ.has_step: + step = c.builder.extract_value(val, 2) + step_is_none = c.builder.icmp_signed("==", step, none_val) + step = c.builder.select( + step_is_none, + c.pyapi.get_null_object(), + c.box(types.int64, step), + ) + else: + step = c.pyapi.get_null_object() + + slice_val = slice_new(c.pyapi, start, stop, step) + + return slice_val + + @numba.extending.overload(operator.contains) + def in_seq_empty_tuple(x, y): + if isinstance(x, types.Tuple) and not x.types: + return lambda x, y: False + + +enable_slice_boxing() + + +@numba_funcify.register(MakeSlice) +def numba_funcify_MakeSlice(op, **kwargs): + @numba_njit + def makeslice(*x): + return slice(*x) + + return makeslice @numba_funcify.register(Subtensor) From f523fd44576567fc223ac616143e6720d5f81520 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:01:50 +0200 Subject: [PATCH 06/16] Move shape Ops dispatchers to their own file --- pytensor/link/numba/dispatch/__init__.py | 1 + pytensor/link/numba/dispatch/basic.py | 73 ----------------- pytensor/link/numba/dispatch/shape.py | 78 ++++++++++++++++++ tests/link/numba/test_basic.py | 92 --------------------- tests/link/numba/test_shape.py | 100 +++++++++++++++++++++++ 5 files changed, 179 insertions(+), 165 deletions(-) create mode 100644 pytensor/link/numba/dispatch/shape.py create mode 100644 tests/link/numba/test_shape.py diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1fefb1d06d..1541331d31 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -9,6 +9,7 @@ import pytensor.link.numba.dispatch.random import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scalar +import pytensor.link.numba.dispatch.shape import pytensor.link.numba.dispatch.signal import pytensor.link.numba.dispatch.slinalg import pytensor.link.numba.dispatch.sparse diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8cb556c283..04008c5ed8 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,10 +1,8 @@ import warnings from copy import copy from functools import singledispatch -from textwrap import dedent import numba -import numba.np.unsafe.ndarray as numba_ndarray import numpy as np from numba import types from numba.core.errors import NumbaWarning, TypingError @@ -22,7 +20,6 @@ from pytensor.ifelse import IfElse from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import ( - compile_function_src, fgraph_to_python, ) from pytensor.scalar.basic import ScalarType @@ -30,10 +27,8 @@ from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import NoneConst def numba_njit(*args, fastmath=None, 
**kwargs): @@ -322,26 +317,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): return deepcopyop -@numba_funcify.register(Shape) -def numba_funcify_Shape(op, **kwargs): - @numba_njit - def shape(x): - return np.asarray(np.shape(x)) - - return shape - - -@numba_funcify.register(Shape_i) -def numba_funcify_Shape_i(op, **kwargs): - i = op.i - - @numba_njit - def shape_i(x): - return np.asarray(np.shape(x)[i]) - - return shape_i - - @numba_funcify.register(SortOp) def numba_funcify_SortOp(op, node, **kwargs): @numba_njit @@ -423,54 +398,6 @@ def codegen(context, builder, signature, args): return sig, codegen -@numba_funcify.register(Reshape) -def numba_funcify_Reshape(op, **kwargs): - ndim = op.ndim - - if ndim == 0: - - @numba_njit - def reshape(x, shape): - return np.asarray(x.item()) - - else: - - @numba_njit - def reshape(x, shape): - # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. - return np.reshape( - np.ascontiguousarray(np.asarray(x)), - numba_ndarray.to_fixed_tuple(shape, ndim), - ) - - return reshape - - -@numba_funcify.register(SpecifyShape) -def numba_funcify_SpecifyShape(op, node, **kwargs): - shape_inputs = node.inputs[1:] - shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] - - func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( - zip(shape_inputs, shape_input_names, strict=True) - ) - if shape_input is not NoneConst - ] - - func = dedent( - f""" - def specify_shape(x, {create_arg_string(shape_input_names)}): - {"; ".join(func_conditions)} - return x - """ - ) - - specify_shape = compile_function_src(func, "specify_shape", globals()) - return numba_njit(specify_shape) - - def int_to_float_fn(inputs, out_dtype): """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py new file mode 100644 index 0000000000..1d7aa036a5 --- /dev/null +++ b/pytensor/link/numba/dispatch/shape.py @@ -0,0 +1,78 @@ +from textwrap import dedent + +import numpy as np +from numba.np.unsafe import ndarray as numba_ndarray + +from pytensor.link.numba.dispatch import numba_funcify +from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit +from pytensor.link.utils import compile_function_src +from pytensor.tensor import NoneConst +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape + + +@numba_funcify.register(Shape) +def numba_funcify_Shape(op, **kwargs): + @numba_njit + def shape(x): + return np.asarray(np.shape(x)) + + return shape + + +@numba_funcify.register(Shape_i) +def numba_funcify_Shape_i(op, **kwargs): + i = op.i + + @numba_njit + def shape_i(x): + return np.asarray(np.shape(x)[i]) + + return shape_i + + +@numba_funcify.register(SpecifyShape) +def numba_funcify_SpecifyShape(op, node, **kwargs): + shape_inputs = node.inputs[1:] + shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] + + func_conditions = [ + f"assert x.shape[{i}] == {shape_input_names}" + for i, (shape_input, shape_input_names) in enumerate( + zip(shape_inputs, shape_input_names, strict=True) + ) + if shape_input is not NoneConst + ] + + func = dedent( + f""" + def specify_shape(x, {create_arg_string(shape_input_names)}): + {"; ".join(func_conditions)} + return x + """ + ) + + specify_shape = compile_function_src(func, "specify_shape", globals()) + return numba_njit(specify_shape) + + +@numba_funcify.register(Reshape) +def 
numba_funcify_Reshape(op, **kwargs): + ndim = op.ndim + + if ndim == 0: + + @numba_njit + def reshape(x, shape): + return np.asarray(x.item()) + + else: + + @numba_njit + def reshape(x, shape): + # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. + return np.reshape( + np.ascontiguousarray(np.asarray(x)), + numba_ndarray.to_fixed_tuple(shape, ndim), + ) + + return reshape diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index fb92d1cf0c..2a03e79ac7 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -31,7 +31,6 @@ from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.sort import ArgSortOp, SortOp @@ -332,22 +331,6 @@ def test_create_numba_signature(v, expected, force_scalar): assert res == expected -@pytest.mark.parametrize( - "x, i", - [ - (np.zeros((20, 3)), 1), - ], -) -def test_Shape(x, i): - g = Shape()(pt.as_tensor_variable(x)) - - compare_numba_and_py([], [g], []) - - g = Shape_i(i)(pt.as_tensor_variable(x)) - - compare_numba_and_py([], [g], []) - - @pytest.mark.parametrize( "x", [ @@ -412,81 +395,6 @@ def test_ArgSort(x, axis, kind, exc): compare_numba_and_py([], [g], []) -@pytest.mark.parametrize( - "v, shape, ndim", - [ - ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0), - ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2), - ( - (pt.vector(), np.arange(4, dtype=config.floatX)), - (pt.lvector(), np.array([2, 2], dtype="int64")), - 2, - ), - ], -) -def test_Reshape(v, shape, ndim): - v, v_test_value = v - shape, shape_test_value = shape - - g = Reshape(ndim)(v, shape) - inputs = [v] if not isinstance(shape, Variable) else [v, shape] - test_values = ( - [v_test_value] - if not isinstance(shape, Variable) - else [v_test_value, shape_test_value] - ) - compare_numba_and_py( - inputs, - [g], - test_values, - ) - - -def test_Reshape_scalar(): - v = pt.vector() - v_test_value = np.array([1.0], dtype=config.floatX) - g = Reshape(1)(v[0], (1,)) - - compare_numba_and_py( - [v], - g, - [v_test_value], - ) - - -@pytest.mark.parametrize( - "v, shape, fails", - [ - ( - (pt.matrix(), np.array([[1.0]], dtype=config.floatX)), - (1, 1), - False, - ), - ( - (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - (1, 1), - True, - ), - ( - (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - (1, None), - False, - ), - ], -) -def test_SpecifyShape(v, shape, fails): - v, v_test_value = v - g = SpecifyShape()(v, *shape) - cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) - - with cm: - compare_numba_and_py( - [v], - [g], - [v_test_value], - ) - - def test_ViewOp(): v = pt.vector() v_test_value = np.arange(4, dtype=config.floatX) diff --git a/tests/link/numba/test_shape.py b/tests/link/numba/test_shape.py new file mode 100644 index 0000000000..1412186cf2 --- /dev/null +++ b/tests/link/numba/test_shape.py @@ -0,0 +1,100 @@ +import contextlib + +import numpy as np +import pytest + +from pytensor import Variable, config +from pytensor import tensor as pt +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from tests.link.numba.test_basic import compare_numba_and_py + + +@pytest.mark.parametrize( + "x, i", + [ + (np.zeros((20, 3)), 1), + ], +) +def test_Shape(x, i): + g = Shape()(pt.as_tensor_variable(x)) + + compare_numba_and_py([], [g], []) + + g 
= Shape_i(i)(pt.as_tensor_variable(x)) + + compare_numba_and_py([], [g], []) + + +@pytest.mark.parametrize( + "v, shape, ndim", + [ + ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0), + ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2), + ( + (pt.vector(), np.arange(4, dtype=config.floatX)), + (pt.lvector(), np.array([2, 2], dtype="int64")), + 2, + ), + ], +) +def test_Reshape(v, shape, ndim): + v, v_test_value = v + shape, shape_test_value = shape + + g = Reshape(ndim)(v, shape) + inputs = [v] if not isinstance(shape, Variable) else [v, shape] + test_values = ( + [v_test_value] + if not isinstance(shape, Variable) + else [v_test_value, shape_test_value] + ) + compare_numba_and_py( + inputs, + [g], + test_values, + ) + + +def test_Reshape_scalar(): + v = pt.vector() + v_test_value = np.array([1.0], dtype=config.floatX) + g = Reshape(1)(v[0], (1,)) + + compare_numba_and_py( + [v], + g, + [v_test_value], + ) + + +@pytest.mark.parametrize( + "v, shape, fails", + [ + ( + (pt.matrix(), np.array([[1.0]], dtype=config.floatX)), + (1, 1), + False, + ), + ( + (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, 1), + True, + ), + ( + (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, None), + False, + ), + ], +) +def test_SpecifyShape(v, shape, fails): + v, v_test_value = v + g = SpecifyShape()(v, *shape) + cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) + + with cm: + compare_numba_and_py( + [v], + [g], + [v_test_value], + ) From 6c1c9da5d187d87d585a35a9f45f6edd1c2c6f5d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:10:30 +0200 Subject: [PATCH 07/16] Move sort Ops dispatchers to their own file --- pytensor/link/numba/dispatch/__init__.py | 1 + pytensor/link/numba/dispatch/basic.py | 63 --------------------- pytensor/link/numba/dispatch/sort.py | 69 +++++++++++++++++++++++ tests/link/numba/test_basic.py | 65 --------------------- tests/link/numba/test_sort.py | 72 ++++++++++++++++++++++++ 5 files changed, 142 insertions(+), 128 deletions(-) create mode 100644 pytensor/link/numba/dispatch/sort.py create mode 100644 tests/link/numba/test_sort.py diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1541331d31..50e61a27ab 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -12,6 +12,7 @@ import pytensor.link.numba.dispatch.shape import pytensor.link.numba.dispatch.signal import pytensor.link.numba.dispatch.slinalg +import pytensor.link.numba.dispatch.sort import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.subtensor import pytensor.link.numba.dispatch.tensor_basic diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 04008c5ed8..d176d72ba8 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -27,7 +27,6 @@ from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot -from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType @@ -317,68 +316,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): return deepcopyop -@numba_funcify.register(SortOp) -def numba_funcify_SortOp(op, node, **kwargs): - @numba_njit - def sort_f(a, axis): - axis = axis.item() - - a_swapped = np.swapaxes(a, axis, -1) - a_sorted = np.sort(a_swapped) - a_sorted_swapped = np.swapaxes(a_sorted, -1, axis) - - 
return a_sorted_swapped - - if op.kind != "quicksort": - warnings.warn( - ( - f'Numba function sort doesn\'t support kind="{op.kind}"' - " switching to `quicksort`." - ), - UserWarning, - ) - - return sort_f - - -@numba_funcify.register(ArgSortOp) -def numba_funcify_ArgSortOp(op, node, **kwargs): - def argsort_f_kind(kind): - @numba_njit - def argort_vec(X, axis): - axis = axis.item() - - Y = np.swapaxes(X, axis, 0) - result = np.empty_like(Y, dtype="int64") - - indices = list(np.ndindex(Y.shape[1:])) - - for idx in indices: - result[(slice(None), *idx)] = np.argsort( - Y[(slice(None), *idx)], kind=kind - ) - - result = np.swapaxes(result, 0, axis) - - return result - - return argort_vec - - kind = op.kind - - if kind not in ["quicksort", "mergesort"]: - kind = "quicksort" - warnings.warn( - ( - f'Numba function argsort doesn\'t support kind="{op.kind}"' - " switching to `quicksort`." - ), - UserWarning, - ) - - return argsort_f_kind(kind) - - @numba.extending.intrinsic def direct_cast(typingctx, val, typ): if isinstance(typ, numba.types.TypeRef): diff --git a/pytensor/link/numba/dispatch/sort.py b/pytensor/link/numba/dispatch/sort.py new file mode 100644 index 0000000000..a2747bf568 --- /dev/null +++ b/pytensor/link/numba/dispatch/sort.py @@ -0,0 +1,69 @@ +import warnings + +import numpy as np + +from pytensor.link.numba.dispatch import numba_funcify +from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.tensor.sort import ArgSortOp, SortOp + + +@numba_funcify.register(SortOp) +def numba_funcify_SortOp(op, node, **kwargs): + @numba_njit + def sort_f(a, axis): + axis = axis.item() + + a_swapped = np.swapaxes(a, axis, -1) + a_sorted = np.sort(a_swapped) + a_sorted_swapped = np.swapaxes(a_sorted, -1, axis) + + return a_sorted_swapped + + if op.kind != "quicksort": + warnings.warn( + ( + f'Numba function sort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) + + return sort_f + + +@numba_funcify.register(ArgSortOp) +def numba_funcify_ArgSortOp(op, node, **kwargs): + def argsort_f_kind(kind): + @numba_njit + def argort_vec(X, axis): + axis = axis.item() + + Y = np.swapaxes(X, axis, 0) + result = np.empty_like(Y, dtype="int64") + + indices = list(np.ndindex(Y.shape[1:])) + + for idx in indices: + result[(slice(None), *idx)] = np.argsort( + Y[(slice(None), *idx)], kind=kind + ) + + result = np.swapaxes(result, 0, axis) + + return result + + return argort_vec + + kind = op.kind + + if kind not in ["quicksort", "mergesort"]: + kind = "quicksort" + warnings.warn( + ( + f'Numba function argsort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." 
+ ), + UserWarning, + ) + + return argsort_f_kind(kind) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 2a03e79ac7..5b911e18db 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -31,7 +31,6 @@ from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.sort import ArgSortOp, SortOp if TYPE_CHECKING: @@ -331,70 +330,6 @@ def test_create_numba_signature(v, expected, force_scalar): assert res == expected -@pytest.mark.parametrize( - "x", - [ - [], # Empty list - [3, 2, 1], # Simple list - np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array - ], -) -@pytest.mark.parametrize("axis", [0, -1, None]) -@pytest.mark.parametrize( - ("kind", "exc"), - [ - ["quicksort", None], - ["mergesort", UserWarning], - ["heapsort", UserWarning], - ["stable", UserWarning], - ], -) -def test_Sort(x, axis, kind, exc): - if axis: - g = SortOp(kind)(pt.as_tensor_variable(x), axis) - else: - g = SortOp(kind)(pt.as_tensor_variable(x)) - - cm = contextlib.suppress() if not exc else pytest.warns(exc) - - with cm: - compare_numba_and_py([], [g], []) - - -@pytest.mark.parametrize( - "x", - [ - [], # Empty list - [3, 2, 1], # Simple list - None, # Multi-dimensional array (see below) - ], -) -@pytest.mark.parametrize("axis", [0, -1, None]) -@pytest.mark.parametrize( - ("kind", "exc"), - [ - ["quicksort", None], - ["heapsort", None], - ["stable", UserWarning], - ], -) -def test_ArgSort(x, axis, kind, exc): - if x is None: - x = np.arange(5 * 5 * 5 * 5) - np.random.shuffle(x) - x = np.reshape(x, (5, 5, 5, 5)) - - if axis: - g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis) - else: - g = ArgSortOp(kind)(pt.as_tensor_variable(x)) - - cm = contextlib.suppress() if not exc else pytest.warns(exc) - - with cm: - compare_numba_and_py([], [g], []) - - def test_ViewOp(): v = pt.vector() v_test_value = np.arange(4, dtype=config.floatX) diff --git a/tests/link/numba/test_sort.py b/tests/link/numba/test_sort.py new file mode 100644 index 0000000000..d6c6072530 --- /dev/null +++ b/tests/link/numba/test_sort.py @@ -0,0 +1,72 @@ +import contextlib + +import numpy as np +import pytest + +from pytensor import tensor as pt +from pytensor.tensor.sort import ArgSortOp, SortOp +from tests.link.numba.test_basic import compare_numba_and_py + + +@pytest.mark.parametrize( + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["mergesort", UserWarning], + ["heapsort", UserWarning], + ["stable", UserWarning], + ], +) +def test_Sort(x, axis, kind, exc): + if axis: + g = SortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = SortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: + compare_numba_and_py([], [g], []) + + +@pytest.mark.parametrize( + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + None, # Multi-dimensional array (see below) + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["heapsort", None], + ["stable", UserWarning], + ], +) +def test_ArgSort(x, axis, kind, exc): + if x is None: + x = np.arange(5 * 5 * 5 * 5) + np.random.shuffle(x) + x = np.reshape(x, (5, 5, 5, 5)) + + if axis: + g = 
ArgSortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = ArgSortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: + compare_numba_and_py([], [g], []) From 80667205c85634ed2749d774d7a09d89fda08e2a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:13:45 +0200 Subject: [PATCH 08/16] Move NonZero Op dispatcher to tensor_basic --- pytensor/link/numba/dispatch/basic.py | 13 ------------- pytensor/link/numba/dispatch/tensor_basic.py | 19 ++++++++++++++++++- tests/link/numba/test_basic.py | 14 -------------- tests/link/numba/test_tensor_basic.py | 14 ++++++++++++++ 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index d176d72ba8..d1fd0edbb1 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -24,7 +24,6 @@ ) from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType -from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.type import TensorType @@ -457,15 +456,3 @@ def ifelse(cond, *args): return res[0] return ifelse - - -@numba_funcify.register(Nonzero) -def numba_funcify_Nonzero(op, node, **kwargs): - @numba_njit - def nonzero(a): - result_tuple = np.nonzero(a) - if a.ndim == 1: - return result_tuple[0] - return list(result_tuple) - - return nonzero diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 3a9d8767b9..6e77b89826 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -3,7 +3,11 @@ import numpy as np from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify +from pytensor.link.numba.dispatch.basic import ( + create_tuple_string, + numba_funcify, + numba_njit, +) from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.tensor.basic import ( Alloc, @@ -13,6 +17,7 @@ Eye, Join, MakeVector, + Nonzero, ScalarFromTensor, Split, TensorFromScalar, @@ -235,3 +240,15 @@ def scalar_from_tensor(x): return numba_basic.to_scalar(x) return scalar_from_tensor + + +@numba_funcify.register(Nonzero) +def numba_funcify_Nonzero(op, node, **kwargs): + @numba_njit + def nonzero(a): + result_tuple = np.nonzero(a) + if a.ndim == 1: + return result_tuple[0] + return list(result_tuple) + + return nonzero diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 5b911e18db..2f1954052d 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -718,20 +718,6 @@ def test_function_overhead(mode, benchmark): benchmark(fn, test_x) -@pytest.mark.parametrize( - "input_data", - [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])], -) -def test_Nonzero(input_data): - a = pt.tensor("a", shape=(None,) * input_data.ndim) - - graph_outputs = pt.nonzero(a) - - compare_numba_and_py( - graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data] - ) - - @pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) def test_mat_vec_dot_performance(dtype, benchmark): A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype) diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 625246e340..233b7bcb19 
100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -326,3 +326,17 @@ def test_Eye(n, m, k, dtype): g, [n_test, m_test] if m is not None else [n_test], ) + + +@pytest.mark.parametrize( + "input_data", + [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])], +) +def test_Nonzero(input_data): + a = pt.tensor("a", shape=(None,) * input_data.ndim) + + graph_outputs = pt.nonzero(a) + + compare_numba_and_py( + graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data] + ) From 91cb2c0f1dea9efd05849f1bc3ed284808cf796a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:37:09 +0200 Subject: [PATCH 09/16] Move dot Op dispatchers to Elemwise They are actually defined in tensor/math.py, but this is better than being in `basic.py` --- pytensor/link/numba/dispatch/basic.py | 67 ---------------- pytensor/link/numba/dispatch/elemwise.py | 68 ++++++++++++++++- tests/link/numba/test_basic.py | 97 ------------------------ tests/link/numba/test_elemwise.py | 96 +++++++++++++++++++++++ 4 files changed, 163 insertions(+), 165 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index d1fd0edbb1..794130819a 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -24,8 +24,6 @@ ) from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType -from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.math import Dot from pytensor.tensor.type import TensorType @@ -364,71 +362,6 @@ def inputs_cast(x): return inputs_cast -@numba_funcify.register(Dot) -def numba_funcify_Dot(op, node, **kwargs): - # Numba's `np.dot` does not support integer dtypes, so we need to cast to float. 
- x, y = node.inputs - [out] = node.outputs - - x_dtype = x.type.dtype - y_dtype = y.type.dtype - dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" - out_dtype = out.type.dtype - - if x_dtype == dot_dtype and y_dtype == dot_dtype: - - @numba_njit - def dot(x, y): - return np.asarray(np.dot(x, y)) - - elif x_dtype == dot_dtype and y_dtype != dot_dtype: - - @numba_njit - def dot(x, y): - return np.asarray(np.dot(x, y.astype(dot_dtype))) - - elif x_dtype != dot_dtype and y_dtype == dot_dtype: - - @numba_njit - def dot(x, y): - return np.asarray(np.dot(x.astype(dot_dtype), y)) - - else: - - @numba_njit() - def dot(x, y): - return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) - - if out_dtype == dot_dtype: - return dot - - else: - - @numba_njit - def dot_with_cast(x, y): - return dot(x, y).astype(out_dtype) - - return dot_with_cast - - -@numba_funcify.register(BatchedDot) -def numba_funcify_BatchedDot(op, node, **kwargs): - dtype = node.outputs[0].type.numpy_dtype - - @numba_njit - def batched_dot(x, y): - # Numba does not support 3D matmul - # https://github.com/numba/numba/issues/3804 - shape = x.shape[:-1] + y.shape[2:] - z0 = np.empty(shape, dtype=dtype) - for i in range(z0.shape[0]): - z0[i] = np.dot(x[i], y[i]) - - return z0 - - return batched_dot - - @numba_funcify.register(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 5ee056f43f..807a60a6d3 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -35,8 +35,9 @@ scalar_maximum, ) from pytensor.scalar.basic import add as add_as +from pytensor.tensor.blas import BatchedDot from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum +from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -599,3 +600,68 @@ def argmax(x): return max_idx_res return argmax + + +@numba_funcify.register(Dot) +def numba_funcify_Dot(op, node, **kwargs): + # Numba's `np.dot` does not support integer dtypes, so we need to cast to float. 
+ x, y = node.inputs + [out] = node.outputs + + x_dtype = x.type.dtype + y_dtype = y.type.dtype + dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" + out_dtype = out.type.dtype + + if x_dtype == dot_dtype and y_dtype == dot_dtype: + + @numba_njit + def dot(x, y): + return np.asarray(np.dot(x, y)) + + elif x_dtype == dot_dtype and y_dtype != dot_dtype: + + @numba_njit + def dot(x, y): + return np.asarray(np.dot(x, y.astype(dot_dtype))) + + elif x_dtype != dot_dtype and y_dtype == dot_dtype: + + @numba_njit + def dot(x, y): + return np.asarray(np.dot(x.astype(dot_dtype), y)) + + else: + + @numba_njit() + def dot(x, y): + return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) + + if out_dtype == dot_dtype: + return dot + + else: + + @numba_njit + def dot_with_cast(x, y): + return dot(x, y).astype(out_dtype) + + return dot_with_cast + + +@numba_funcify.register(BatchedDot) +def numba_funcify_BatchedDot(op, node, **kwargs): + dtype = node.outputs[0].type.numpy_dtype + + @numba_njit + def batched_dot(x, y): + # Numba does not support 3D matmul + # https://github.com/numba/numba/issues/3804 + shape = x.shape[:-1] + y.shape[2:] + z0 = np.empty(shape, dtype=dtype) + for i in range(z0.shape[0]): + z0[i] = np.dot(x[i], y[i]) + + return z0 + + return batched_dot diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 2f1954052d..49e885abb3 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -14,7 +14,6 @@ import pytensor.scalar as ps import pytensor.tensor as pt -import pytensor.tensor.math as ptm from pytensor import config, shared from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function @@ -29,7 +28,6 @@ from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op from pytensor.scalar.basic import ScalarOp, as_scalar -from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import Elemwise @@ -407,86 +405,6 @@ def test_perform_type_convert(): compare_numba_and_py([x], out, [x_test_value]) -@pytest.mark.parametrize( - "x, y", - [ - ( - (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), - (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), - ), - ( - (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")), - (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")), - ), - ( - (pt.lmatrix(), rng.poisson(size=(3, 2))), - (pt.fvector(), rng.random(size=(2,)).astype("float32")), - ), - ( - (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), - (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), - ), - ( - (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)), - (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)), - ), - ], -) -def test_Dot(x, y): - x, x_test_value = x - y, y_test_value = y - - g = ptm.dot(x, y) - - compare_numba_and_py( - [x, y], - [g], - [x_test_value, y_test_value], - ) - - -@pytest.mark.parametrize( - "x, y, exc", - [ - ( - ( - pt.dtensor3(), - rng.random(size=(2, 3, 3)).astype("float64"), - ), - ( - pt.dtensor3(), - rng.random(size=(2, 3, 3)).astype("float64"), - ), - None, - ), - ( - ( - pt.dtensor3(), - rng.random(size=(2, 3, 3)).astype("float64"), - ), - ( - pt.ltensor3(), - rng.poisson(size=(2, 3, 3)).astype("int64"), - ), - None, - ), - ], -) -def test_BatchedDot(x, y, exc): - x, x_test_value = x - y, y_test_value = y - - g = blas.BatchedDot()(x, y) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: 
- compare_numba_and_py( - [x, y], - g, - [x_test_value, y_test_value], - ) - - def test_shared(): a = shared(np.array([1, 2, 3], dtype=config.floatX)) @@ -716,18 +634,3 @@ def test_function_overhead(mode, benchmark): assert np.sum(fn(test_x)) == 1000 benchmark(fn, test_x) - - -@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) -def test_mat_vec_dot_performance(dtype, benchmark): - A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype) - x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype) - out = ptm.dot(A, x) - - fn = function([A, x], out, mode="NUMBA", trust_input=True) - - rng = np.random.default_rng(948) - A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype) - x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype) - np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4) - benchmark(fn, A_test, x_test) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 84875dac97..954656cebe 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -13,6 +13,7 @@ from pytensor.compile.ops import deep_copy_op from pytensor.gradient import grad from pytensor.scalar import Composite, float64 +from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -670,3 +671,98 @@ def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark): @pytest.mark.parametrize("c_contiguous", (True, False)) def test_dimshuffle(self, c_contiguous, benchmark): dimshuffle_benchmark("NUMBA", c_contiguous, benchmark) + + +@pytest.mark.parametrize( + "x, y", + [ + ( + (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), + (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), + ), + ( + (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")), + (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")), + ), + ( + (pt.lmatrix(), rng.poisson(size=(3, 2))), + (pt.fvector(), rng.random(size=(2,)).astype("float32")), + ), + ( + (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), + (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), + ), + ( + (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)), + (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)), + ), + ], +) +def test_Dot(x, y): + x, x_test_value = x + y, y_test_value = y + + g = ptm.dot(x, y) + + compare_numba_and_py( + [x, y], + [g], + [x_test_value, y_test_value], + ) + + +@pytest.mark.parametrize( + "x, y, exc", + [ + ( + ( + pt.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), + ), + ( + pt.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), + ), + None, + ), + ( + ( + pt.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), + ), + ( + pt.ltensor3(), + rng.poisson(size=(2, 3, 3)).astype("int64"), + ), + None, + ), + ], +) +def test_BatchedDot(x, y, exc): + x, x_test_value = x + y, y_test_value = y + + g = blas.BatchedDot()(x, y) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + [x, y], + g, + [x_test_value, y_test_value], + ) + + +@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) +def test_mat_vec_dot_performance(dtype, benchmark): + A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else 
dtype) + x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype) + out = ptm.dot(A, x) + + fn = function([A, x], out, mode="NUMBA", trust_input=True) + + rng = np.random.default_rng(948) + A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype) + x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype) + np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4) + benchmark(fn, A_test, x_test) From 0512b32e3d4b5d5e8e2ccb122008d9e5ca15adb5 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:28:56 +0200 Subject: [PATCH 10/16] Don't use overload for deepcopy --- pytensor/link/numba/dispatch/basic.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 794130819a..2f5b4b3fd7 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,5 +1,4 @@ import warnings -from copy import copy from functools import singledispatch import numba @@ -7,7 +6,6 @@ from numba import types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import overload from pytensor import In, config from pytensor.compile import NUMBA @@ -296,21 +294,21 @@ def numba_funcify_FunctionGraph( ) -def deepcopyop(x): - return copy(x) - +@numba_funcify.register(DeepCopyOp) +def numba_funcify_DeepCopyOp(op, node, **kwargs): + if isinstance(node.inputs[0].type, TensorType): -@overload(deepcopyop) -def dispatch_deepcopyop(x): - if isinstance(x, types.Array): - return lambda x: np.copy(x) + @numba_njit + def deepcopy(x): + return np.copy(x) - return lambda x: x + else: + @numba_njit + def deepcopy(x): + return x -@numba_funcify.register(DeepCopyOp) -def numba_funcify_DeepCopyOp(op, node, **kwargs): - return deepcopyop + return deepcopy @numba.extending.intrinsic From a458be69c977230d4e50e48607b50264663553a4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:29:10 +0200 Subject: [PATCH 11/16] Remove `to_scalar` helper --- pytensor/link/numba/dispatch/basic.py | 17 +----------- pytensor/link/numba/dispatch/extra_ops.py | 8 +++--- pytensor/link/numba/dispatch/scalar.py | 12 +++----- pytensor/link/numba/dispatch/scan.py | 5 ++-- pytensor/link/numba/dispatch/tensor_basic.py | 29 ++++++++++---------- tests/link/numba/test_basic.py | 7 ----- 6 files changed, 25 insertions(+), 53 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2f5b4b3fd7..a7f216fcbf 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -3,8 +3,7 @@ import numba import numpy as np -from numba import types -from numba.core.errors import NumbaWarning, TypingError +from numba.core.errors import NumbaWarning from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from pytensor import In, config @@ -135,20 +134,6 @@ def create_numba_signature( return numba.types.void(*input_types) -def to_scalar(x): - return np.asarray(x).item() - - -@numba.extending.overload(to_scalar) -def impl_to_scalar(x): - if isinstance(x, numba.types.Number | numba.types.Boolean): - return lambda x: x - elif isinstance(x, numba.types.Array): - return lambda x: x.item() - else: - raise TypingError(f"{x} must be a scalar compatible type.") - - def create_tuple_creator(f, n): """Construct a compile-time ``tuple``-comprehension-like loop. 
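The remaining hunks in this patch replace the `numba_basic.to_scalar(x)` calls with direct use of the inputs, typically `x.item()`. A minimal standalone sketch of that pattern, assuming a 0-d ndarray argument and a hypothetical function name (`scalar_plus_one` is illustrative only, not part of the patch):

    import numba
    import numpy as np


    @numba.njit
    def scalar_plus_one(x):
        # Numba's nopython mode supports ndarray.item() directly, so the
        # removed to_scalar(x) indirection reduces to x.item() here.
        return x.item() + 1


    assert scalar_plus_one(np.array(41.0)) == 42.0
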
diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index f7700acf47..5f8495b804 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -26,7 +26,7 @@ def numba_funcify_Bartlett(op, **kwargs): @numba_basic.numba_njit(inline="always") def bartlett(x): - return np.bartlett(numba_basic.to_scalar(x)) + return np.bartlett(x.item()) return bartlett @@ -112,12 +112,12 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs): @numba_basic.numba_njit def filldiagonaloffset(a, val, offset): height, width = a.shape - + offset_item = offset.item() if offset >= 0: - start = numba_basic.to_scalar(offset) + start = offset_item num_of_step = min(min(width, height), width - offset) else: - start = -numba_basic.to_scalar(offset) * a.shape[1] + start = -offset_item * a.shape[1] num_of_step = min(min(width, height), height + offset) step = a.shape[1] + 1 diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 4e0019b74b..d5d414f716 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -210,14 +210,10 @@ def identity(x): def numba_funcify_Clip(op, **kwargs): @numba_basic.numba_njit def clip(x, min_val, max_val): - x = numba_basic.to_scalar(x) - min_scalar = numba_basic.to_scalar(min_val) - max_scalar = numba_basic.to_scalar(max_val) - - if x < min_scalar: - return min_scalar - elif x > max_scalar: - return max_scalar + if x < min_val: + return min_val + elif x > max_val: + return max_val else: return x diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c75a4cf890..694f341ed4 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -365,7 +365,7 @@ def add_output_storage_post_proc_stmt( storage_alloc_stmts.append( dedent( f""" - {storage_size_name} = to_numba_scalar({outer_in_name}) + {storage_size_name} = ({outer_in_name}).item() {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) """ ).strip() @@ -435,10 +435,9 @@ def scan({", ".join(outer_in_names)}): """ global_env = { + "np": np, "scan_inner_func": scan_inner_func, - "to_numba_scalar": numba_basic.to_scalar, } - global_env["np"] = np scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 6e77b89826..c82926364e 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -28,18 +28,17 @@ def numba_funcify_AllocEmpty(op, node, **kwargs): global_env = { "np": np, - "to_scalar": numba_basic.to_scalar, "dtype": np.dtype(op.dtype), } unique_names = unique_name_generator( - ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" + ["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" ) shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( - f"{item_name} = to_scalar({shape_name})" + f"{item_name} = {shape_name}.item()" for item_name, shape_name in zip( shape_var_item_names, shape_var_names, strict=True ) @@ -63,10 +62,10 @@ def allocempty({", ".join(shape_var_names)}): @numba_funcify.register(Alloc) def numba_funcify_Alloc(op, node, **kwargs): - global_env = {"np": np, "to_scalar": numba_basic.to_scalar} + global_env = {"np": np} 
unique_names = unique_name_generator( - ["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"], + ["np", "alloc", "val_np", "val", "scalar_shape", "res"], suffix_sep="_", ) shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]] @@ -110,9 +109,9 @@ def numba_funcify_ARange(op, **kwargs): @numba_basic.numba_njit(inline="always") def arange(start, stop, step): return np.arange( - numba_basic.to_scalar(start), - numba_basic.to_scalar(stop), - numba_basic.to_scalar(step), + start.item(), + stop.item(), + step.item(), dtype=dtype, ) @@ -187,9 +186,9 @@ def numba_funcify_Eye(op, **kwargs): @numba_basic.numba_njit(inline="always") def eye(N, M, k): return np.eye( - numba_basic.to_scalar(N), - numba_basic.to_scalar(M), - numba_basic.to_scalar(k), + N.item(), + M.item(), + k.item(), dtype=dtype, ) @@ -200,16 +199,16 @@ def eye(N, M, k): def numba_funcify_MakeVector(op, node, **kwargs): dtype = np.dtype(op.dtype) - global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype} + global_env = {"np": np, "dtype": dtype} unique_names = unique_name_generator( - ["np", "to_scalar"], + ["np"], suffix_sep="_", ) input_names = [unique_names(v, force_unique=True) for v in node.inputs] def create_list_string(x): - args = ", ".join([f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else [])) + args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else [])) return f"[{args}]" makevector_def_src = f""" @@ -237,7 +236,7 @@ def tensor_from_scalar(x): def numba_funcify_ScalarFromTensor(op, **kwargs): @numba_basic.numba_njit(inline="always") def scalar_from_tensor(x): - return numba_basic.to_scalar(x) + return x.item() return scalar_from_tensor diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 49e885abb3..d706f8a4fd 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -134,12 +134,6 @@ def py_tuple_setitem(t, i, v): ll[i] = v return tuple(ll) - def py_to_scalar(x): - if isinstance(x, np.ndarray): - return x.item() - else: - return x - def njit_noop(*args, **kwargs): if len(args) == 1 and callable(args[0]): return args[0] @@ -155,7 +149,6 @@ def njit_noop(*args, **kwargs): mock.patch( "pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x ), - mock.patch("pytensor.link.numba.dispatch.basic.to_scalar", py_to_scalar), mock.patch( "pytensor.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", lambda dtype: dtype, From 96506301d127adbb825d100124f8ac24d742f020 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:38:03 +0200 Subject: [PATCH 12/16] Reorder functions in numba/dispatch/basic.py Helpers before dispatchers --- pytensor/link/numba/dispatch/basic.py | 130 +++++++++++++------------- 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index a7f216fcbf..471bb341cb 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -166,6 +166,55 @@ def create_arg_string(x): return args +@numba.extending.intrinsic +def direct_cast(typingctx, val, typ): + if isinstance(typ, numba.types.TypeRef): + casted = typ.instance_type + elif isinstance(typ, numba.types.DTypeSpec): + casted = typ.dtype + else: + casted = typ + + sig = casted(casted, typ) + + def codegen(context, builder, signature, args): + val, _ = args + context.nrt.incref(builder, signature.return_type, val) + return val + + return sig, codegen + + +def 
int_to_float_fn(inputs, out_dtype): + """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" + + if ( + all(inp.type.dtype == out_dtype for inp in inputs) + and np.dtype(out_dtype).kind == "f" + ): + + @numba_njit(inline="always") + def inputs_cast(x): + return x + + elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): + args_dtype = np.dtype(f"f{out_dtype.itemsize}") + + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) + + else: + args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) + args_dtype = np.dtype(f"f{args_dtype_sz}") + + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) + + return inputs_cast + + @singledispatch def numba_typify(data, dtype=None, **kwargs): return data @@ -231,6 +280,22 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): return generate_fallback_impl(op, node, storage_map, **kwargs) +@numba_funcify.register(FunctionGraph) +def numba_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="numba_funcified_fgraph", + **kwargs, +): + return fgraph_to_python( + fgraph, + numba_funcify, + type_conversion_fn=numba_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + @numba_funcify.register(OpFromGraph) def numba_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) @@ -263,22 +328,6 @@ def opfromgraph(*inputs): return opfromgraph -@numba_funcify.register(FunctionGraph) -def numba_funcify_FunctionGraph( - fgraph, - node=None, - fgraph_name="numba_funcified_fgraph", - **kwargs, -): - return fgraph_to_python( - fgraph, - numba_funcify, - type_conversion_fn=numba_typify, - fgraph_name=fgraph_name, - **kwargs, - ) - - @numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): if isinstance(node.inputs[0].type, TensorType): @@ -296,55 +345,6 @@ def deepcopy(x): return deepcopy -@numba.extending.intrinsic -def direct_cast(typingctx, val, typ): - if isinstance(typ, numba.types.TypeRef): - casted = typ.instance_type - elif isinstance(typ, numba.types.DTypeSpec): - casted = typ.dtype - else: - casted = typ - - sig = casted(casted, typ) - - def codegen(context, builder, signature, args): - val, _ = args - context.nrt.incref(builder, signature.return_type, val) - return val - - return sig, codegen - - -def int_to_float_fn(inputs, out_dtype): - """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" - - if ( - all(inp.type.dtype == out_dtype for inp in inputs) - and np.dtype(out_dtype).kind == "f" - ): - - @numba_njit(inline="always") - def inputs_cast(x): - return x - - elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): - args_dtype = np.dtype(f"f{out_dtype.itemsize}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - else: - args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) - args_dtype = np.dtype(f"f{args_dtype_sz}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - return inputs_cast - - @numba_funcify.register(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs From 33522f331308afd02624e4855a99bdae28f31ec2 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:06:48 +0200 Subject: [PATCH 13/16] Add error message in Numba implementation of SpecifyShape --- pytensor/link/numba/dispatch/shape.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/shape.py 
b/pytensor/link/numba/dispatch/shape.py index 1d7aa036a5..f7f2c0890d 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -36,11 +36,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( zip(shape_inputs, shape_input_names, strict=True) ) - if shape_input is not NoneConst + if node_dim_input is not NoneConst ] func = dedent( From 165520dd68e9d65bfbd9f6bb201f7c532469eeee Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:40:18 +0200 Subject: [PATCH 14/16] Cleanup sort Op dispatchers --- pytensor/link/numba/dispatch/sort.py | 56 +++++++++++++--------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/pytensor/link/numba/dispatch/sort.py b/pytensor/link/numba/dispatch/sort.py index a2747bf568..bb91d4fc97 100644 --- a/pytensor/link/numba/dispatch/sort.py +++ b/pytensor/link/numba/dispatch/sort.py @@ -9,6 +9,15 @@ @numba_funcify.register(SortOp) def numba_funcify_SortOp(op, node, **kwargs): + if op.kind != "quicksort": + warnings.warn( + ( + f'Numba function sort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) + @numba_njit def sort_f(a, axis): axis = axis.item() @@ -19,41 +28,11 @@ def sort_f(a, axis): return a_sorted_swapped - if op.kind != "quicksort": - warnings.warn( - ( - f'Numba function sort doesn\'t support kind="{op.kind}"' - " switching to `quicksort`." 
- ), - UserWarning, - ) - return sort_f @numba_funcify.register(ArgSortOp) def numba_funcify_ArgSortOp(op, node, **kwargs): - def argsort_f_kind(kind): - @numba_njit - def argort_vec(X, axis): - axis = axis.item() - - Y = np.swapaxes(X, axis, 0) - result = np.empty_like(Y, dtype="int64") - - indices = list(np.ndindex(Y.shape[1:])) - - for idx in indices: - result[(slice(None), *idx)] = np.argsort( - Y[(slice(None), *idx)], kind=kind - ) - - result = np.swapaxes(result, 0, axis) - - return result - - return argort_vec - kind = op.kind if kind not in ["quicksort", "mergesort"]: @@ -66,4 +45,19 @@ def argort_vec(X, axis): UserWarning, ) - return argsort_f_kind(kind) + @numba_njit + def argort_f(X, axis): + axis = axis.item() + + Y = np.swapaxes(X, axis, 0) + result = np.empty_like(Y, dtype="int64") + + indices = list(np.ndindex(Y.shape[1:])) + + for idx in indices: + result[(slice(None), *idx)] = np.argsort(Y[(slice(None), *idx)], kind=kind) + + result = np.swapaxes(result, 0, axis) + return result + + return argort_f From ea088643f8900e0fec18a5be041b0618ea6af306 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 9 Oct 2025 17:44:21 +0200 Subject: [PATCH 15/16] Move TypeCastingOp dispatcher to basic.py This isn't strictly needed but it's a more intuitive placement --- pytensor/link/numba/dispatch/basic.py | 11 ++++++++++- pytensor/link/numba/dispatch/scalar.py | 2 -- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 471bb341cb..0d4217a786 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -10,7 +10,7 @@ from pytensor.compile import NUMBA from pytensor.compile.builders import OpFromGraph from pytensor.compile.function.types import add_supervisor_to_fgraph -from pytensor.compile.ops import DeepCopyOp +from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type @@ -328,6 +328,15 @@ def opfromgraph(*inputs): return opfromgraph +@numba_funcify.register(TypeCastingOp) +def numba_funcify_type_casting(op, **kwargs): + @numba_njit + def identity(x): + return x + + return identity + + @numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): if isinstance(node.inputs[0].type, TensorType): diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index d5d414f716..e26c9371ed 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,7 +2,6 @@ import numpy as np -from pytensor.compile.ops import TypeCastingOp from pytensor.graph.basic import Variable from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -197,7 +196,6 @@ def cast(x): @numba_funcify.register(Identity) -@numba_funcify.register(TypeCastingOp) def numba_funcify_type_casting(op, **kwargs): @numba_basic.numba_njit def identity(x): From c594eb41c634c6826a1e80f70c34c98ce9c57946 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 14 Oct 2025 14:40:21 +0200 Subject: [PATCH 16/16] Test numba slice boxing and fix representation of None stop with negative step --- pytensor/link/numba/dispatch/subtensor.py | 34 ++++++++++++------- tests/link/numba/test_subtensor.py | 41 +++++++++++++++++++++++ 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/pytensor/link/numba/dispatch/subtensor.py 
b/pytensor/link/numba/dispatch/subtensor.py index 4727eaf337..5aade827cb 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -46,33 +46,43 @@ def box_slice(typ, val, c): """ start = c.builder.extract_value(val, 0) stop = c.builder.extract_value(val, 1) + step = c.builder.extract_value(val, 2) if typ.has_step else None + # Numba uses sys.maxsize and -sys.maxsize-1 to represent None + # We want to use None in the Python representation none_val = ir.Constant(ir.IntType(64), sys.maxsize) + neg_none_val = ir.Constant(ir.IntType(64), -sys.maxsize - 1) + none_obj = c.pyapi.get_null_object() - start_is_none = c.builder.icmp_signed("==", start, none_val) start = c.builder.select( - start_is_none, - c.pyapi.get_null_object(), + c.builder.icmp_signed("==", start, none_val), + none_obj, c.box(types.int64, start), ) - stop_is_none = c.builder.icmp_signed("==", stop, none_val) + # None stop is represented as neg_none_val when step is negative + if step is not None: + stop_none_val = c.builder.select( + c.builder.icmp_signed(">", step, ir.Constant(ir.IntType(64), 0)), + none_val, + neg_none_val, + ) + else: + stop_none_val = none_val stop = c.builder.select( - stop_is_none, - c.pyapi.get_null_object(), + c.builder.icmp_signed("==", stop, stop_none_val), + none_obj, c.box(types.int64, stop), ) - if typ.has_step: - step = c.builder.extract_value(val, 2) - step_is_none = c.builder.icmp_signed("==", step, none_val) + if step is not None: step = c.builder.select( - step_is_none, - c.pyapi.get_null_object(), + c.builder.icmp_signed("==", step, none_val), + none_obj, c.box(types.int64, step), ) else: - step = c.pyapi.get_null_object() + step = none_obj slice_val = slice_new(c.pyapi, start, stop, step) diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index c9578657f2..17adb892cd 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -3,7 +3,9 @@ import numpy as np import pytest +import pytensor.scalar as ps import pytensor.tensor as pt +from pytensor import Mode, as_symbolic from pytensor.tensor import as_tensor from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -24,6 +26,45 @@ rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) +@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}") +@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}") +@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}") +def test_slice(start, stop, step): + x = ps.int64("x") + + sym_slice = as_symbolic( + slice( + x if start == "x" else start, + x if stop == "x" else stop, + x if step == "x" else step, + ) + ) + + no_opt_mode = Mode(linker="numba", optimizer=None) + evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode) + assert isinstance(evaled_slice, slice) + if start == "x": + assert evaled_slice.start == -5 + elif start is None and (evaled_slice.step is None or evaled_slice.step > 0): + # Numba can convert to 0 (and sometimes does) in this case + assert evaled_slice.start in (None, 0) + else: + assert evaled_slice.start == start + + if stop == "x": + assert evaled_slice.stop == -5 + else: + assert evaled_slice.stop == stop + + if step == "x": + assert evaled_slice.step == -5 + elif step is None: + # Numba can convert to 1 (and sometimes does) in this case + assert evaled_slice.step in (None, 1) + else: + assert evaled_slice.step == step + + @pytest.mark.parametrize( "x, indices", [