Skip to content

Commit 4ef29b1

Browse files
committed
.wip
1 parent 3677f33 commit 4ef29b1

File tree

7 files changed

+42
-47
lines changed

7 files changed

+42
-47
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,11 +50,11 @@ def global_numba_func(func):
5050
return func
5151

5252

53-
def numba_njit(*args, fastmath=None, register_jitable: bool = False, **kwargs):
53+
def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs):
5454
kwargs.setdefault("cache", True)
5555
kwargs.setdefault("no_cpython_wrapper", False)
5656
kwargs.setdefault("no_cfunc_wrapper", False)
57-
# print(kwargs)
57+
5858
if fastmath is None:
5959
if config.numba__fastmath:
6060
# Opinionated default on fastmath flags
@@ -380,11 +380,16 @@ def numba_funcify_FunctionGraph(
380380
fgraph,
381381
node=None,
382382
fgraph_name="numba_funcified_fgraph",
383+
jit_nodes: bool = False,
383384
**kwargs,
384385
):
386+
def numba_funcify_njit(op, node, **kwargs):
387+
jitable_func = numba_funcify(op, node=node, **kwargs)
388+
return numba_njit(lambda *args: jitable_func(*args), register_jitable=False)
389+
385390
return fgraph_to_python(
386391
fgraph,
387-
numba_funcify,
392+
op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify,
388393
type_conversion_fn=numba_typify,
389394
fgraph_name=fgraph_name,
390395
**kwargs,

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 17 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -320,33 +320,23 @@ def elemwise_wrapper(*inputs):
320320

321321
# Pure python implementation, that will be used in tests
322322
def elemwise(*inputs):
323-
inputs = [np.asarray(input) for input in inputs]
323+
Elemwise._check_runtime_broadcast(node, inputs)
324324
inputs_bc = np.broadcast_arrays(*inputs)
325-
shape = inputs[0].shape
326-
for input, bc in zip(inputs, input_bc_patterns, strict=True):
327-
for length, allow_bc, iter_length in zip(
328-
input.shape, bc, shape, strict=True
329-
):
330-
if length == 1 and shape and iter_length != 1 and not allow_bc:
331-
raise ValueError("Broadcast not allowed.")
332-
333-
outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes]
334-
335-
for idx in np.ndindex(shape):
336-
vals = [input[idx] for input in inputs_bc]
337-
outs = scalar_op_fn(*vals)
338-
if not isinstance(outs, tuple):
339-
outs = (outs,)
340-
for out, out_val in zip(outputs, outs, strict=True):
341-
out[idx] = out_val
342-
343-
outputs_summed = []
344-
for output, bc in zip(outputs, output_bc_patterns, strict=True):
345-
axes = tuple(np.nonzero(bc)[0])
346-
outputs_summed.append(output.sum(axes, keepdims=True))
347-
if len(outputs_summed) != 1:
348-
return tuple(outputs_summed)
349-
return outputs_summed[0]
325+
shape = inputs_bc[0].shape
326+
327+
if len(output_dtypes) == 1:
328+
output = np.empty(shape, dtype=output_dtypes[0])
329+
for idx in np.ndindex(shape):
330+
output[idx] = scalar_op_fn(*(inp[idx] for inp in inputs_bc))
331+
return output
332+
333+
else:
334+
outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes]
335+
for idx in np.ndindex(shape):
336+
outs_vals = scalar_op_fn(*(inp[idx] for inp in inputs_bc))
337+
for out, out_val in zip(outputs, outs_vals):
338+
out[idx] = out_val
339+
return outputs
350340

351341
@overload(elemwise)
352342
def ov_elemwise(*inputs):
@@ -594,7 +584,7 @@ def numba_funcify_Argmax(op, node, **kwargs):
594584

595585
if x_ndim == 0:
596586

597-
@numba_basic.numba_njit(inline="always")
587+
@numba_basic.numba_njit
598588
def argmax(x):
599589
return np.array(0, dtype="int64")
600590

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
@numba_funcify.register(Bartlett)
2626
def numba_funcify_Bartlett(op, **kwargs):
27-
@numba_basic.numba_njit(inline="always")
27+
@numba_basic.numba_njit
2828
def bartlett(x):
2929
return np.bartlett(numba_basic.to_scalar(x))
3030

@@ -228,13 +228,13 @@ def repeatop(x, repeats):
228228

229229
if repeats_ndim == 0:
230230

231-
@numba_basic.numba_njit(inline="always")
231+
@numba_basic.numba_njit
232232
def repeatop(x, repeats):
233233
return np.repeat(x, repeats.item())
234234

235235
else:
236236

237-
@numba_basic.numba_njit(inline="always")
237+
@numba_basic.numba_njit
238238
def repeatop(x, repeats):
239239
return np.repeat(x, repeats)
240240

@@ -348,7 +348,7 @@ def searchsorted(a, v, sorter):
348348

349349
else:
350350

351-
@numba_basic.numba_njit(inline="always")
351+
@numba_basic.numba_njit
352352
def searchsorted(a, v):
353353
return np.searchsorted(a, v, side)
354354

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def numba_funcify_Det(op, node, **kwargs):
4949
out_dtype = node.outputs[0].type.numpy_dtype
5050
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
5151

52-
@numba_basic.numba_njit(inline="always")
52+
@numba_basic.numba_njit
5353
def det(x):
5454
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
5555

@@ -128,7 +128,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs):
128128
out_dtype = node.outputs[0].type.numpy_dtype
129129
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
130130

131-
@numba_basic.numba_njit(inline="always")
131+
@numba_basic.numba_njit
132132
def matrix_inverse(x):
133133
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
134134

@@ -140,7 +140,7 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
140140
out_dtype = node.outputs[0].type.numpy_dtype
141141
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
142142

143-
@numba_basic.numba_njit(inline="always")
143+
@numba_basic.numba_njit
144144
def matrixpinv(x):
145145
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
146146

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def numba_funcify_LU(op, node, **kwargs):
118118
if dtype in complex_dtypes:
119119
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
120120

121-
@numba_njit(inline="always")
121+
@numba_njit
122122
def lu(a):
123123
if check_finite:
124124
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,12 @@ def alloc(val, {", ".join(shape_var_names)}):
112112
def numba_funcify_ARange(op, **kwargs):
113113
dtype = np.dtype(op.dtype)
114114

115-
@numba_basic.numba_njit(inline="always")
115+
@numba_basic.numba_njit
116116
def arange(start, stop, step):
117117
return np.arange(
118-
numba_basic.to_scalar(start),
119-
numba_basic.to_scalar(stop),
120-
numba_basic.to_scalar(step),
118+
start.item(),
119+
stop.item(),
120+
step.item(),
121121
dtype=dtype,
122122
)
123123

@@ -164,7 +164,7 @@ def extract_diag(x):
164164
leading_dims = (slice(None),) * axis1
165165
middle_dims = (slice(None),) * (axis2 - axis1 - 1)
166166

167-
@numba_basic.numba_njit(inline="always")
167+
@numba_basic.numba_njit
168168
def extract_diag(x):
169169
if offset >= 0:
170170
diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset))
@@ -234,7 +234,7 @@ def makevector({", ".join(input_names)}):
234234

235235
@numba_funcify.register(TensorFromScalar)
236236
def numba_funcify_TensorFromScalar(op, **kwargs):
237-
@numba_basic.numba_njit(inline="always")
237+
@numba_basic.numba_njit
238238
def tensor_from_scalar(x):
239239
return np.array(x)
240240

@@ -243,8 +243,8 @@ def tensor_from_scalar(x):
243243

244244
@numba_funcify.register(ScalarFromTensor)
245245
def numba_funcify_ScalarFromTensor(op, **kwargs):
246-
@numba_basic.numba_njit(inline="always")
246+
@numba_basic.numba_njit
247247
def scalar_from_tensor(x):
248-
return numba_basic.to_scalar(x)
248+
return x.item()
249249

250250
return scalar_from_tensor

pytensor/link/numba/linker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, *args, vm: bool = False, **kwargs):
1111
def fgraph_convert(self, fgraph, **kwargs):
1212
from pytensor.link.numba.dispatch import numba_funcify
1313

14-
return numba_funcify(fgraph, **kwargs)
14+
return numba_funcify(fgraph, jit_nodes=self.vm, **kwargs)
1515

1616
def jit_compile(self, fn):
1717
if self.vm:

0 commit comments

Comments
 (0)