
Commit a179e6e

Remove to_scalar helper
1 parent 7e9b3f8 commit a179e6e

6 files changed: +25 −53 lines

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 16 deletions
@@ -3,8 +3,7 @@
 
 import numba
 import numpy as np
-from numba import types
-from numba.core.errors import NumbaWarning, TypingError
+from numba.core.errors import NumbaWarning
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 
 from pytensor import In, config
@@ -135,20 +134,6 @@ def create_numba_signature(
     return numba.types.void(*input_types)
 
 
-def to_scalar(x):
-    return np.asarray(x).item()
-
-
-@numba.extending.overload(to_scalar)
-def impl_to_scalar(x):
-    if isinstance(x, numba.types.Number | numba.types.Boolean):
-        return lambda x: x
-    elif isinstance(x, numba.types.Array):
-        return lambda x: x.item()
-    else:
-        raise TypingError(f"{x} must be a scalar compatible type.")
-
-
 def create_tuple_creator(f, n):
     """Construct a compile-time ``tuple``-comprehension-like loop.
 
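The deleted `to_scalar` overload returned Numba scalars unchanged and called `.item()` on arrays; the call sites updated below now call `.item()` directly, which assumes the inputs are arrays (NumPy scalars also expose `.item()`). A minimal sketch of the replacement pattern; `add_one` is illustrative and not part of the codebase:

import numba
import numpy as np

@numba.njit
def add_one(x):
    # .item() extracts the scalar from a 0-d array inside jitted code,
    # covering the Array branch of the removed to_scalar overload.
    return x.item() + 1

print(add_one(np.array(41)))  # -> 42
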
pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@
 def numba_funcify_Bartlett(op, **kwargs):
     @numba_basic.numba_njit(inline="always")
     def bartlett(x):
-        return np.bartlett(numba_basic.to_scalar(x))
+        return np.bartlett(x.item())
 
     return bartlett
 
@@ -112,12 +112,12 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
     @numba_basic.numba_njit
     def filldiagonaloffset(a, val, offset):
         height, width = a.shape
-
+        offset_item = offset.item()
         if offset >= 0:
-            start = numba_basic.to_scalar(offset)
+            start = offset_item
             num_of_step = min(min(width, height), width - offset)
         else:
-            start = -numba_basic.to_scalar(offset) * a.shape[1]
+            start = -offset_item * a.shape[1]
             num_of_step = min(min(width, height), height + offset)
 
         step = a.shape[1] + 1
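In `filldiagonaloffset`, the offset scalar is now extracted once with `offset.item()` and reused in both branches rather than converting inside each branch. A standalone sketch of that hoisting pattern, simplified and with illustrative values (not the full fill-diagonal logic):

import numba
import numpy as np

@numba.njit
def start_index(a, offset):
    # Extract the scalar once and reuse it in both branches.
    offset_item = offset.item()
    if offset_item >= 0:
        return offset_item
    else:
        return -offset_item * a.shape[1]

print(start_index(np.zeros((3, 4)), np.array(-1)))  # -> 4
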

pytensor/link/numba/dispatch/scalar.py

Lines changed: 4 additions & 8 deletions
@@ -210,14 +210,10 @@ def identity(x):
 def numba_funcify_Clip(op, **kwargs):
     @numba_basic.numba_njit
     def clip(x, min_val, max_val):
-        x = numba_basic.to_scalar(x)
-        min_scalar = numba_basic.to_scalar(min_val)
-        max_scalar = numba_basic.to_scalar(max_val)
-
-        if x < min_scalar:
-            return min_scalar
-        elif x > max_scalar:
-            return max_scalar
+        if x < min_val:
+            return min_val
+        elif x > max_val:
+            return max_val
         else:
             return x
 
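The simplified `clip` compares the incoming Numba scalars directly, so no unboxing step is needed. Reproduced as a self-contained sketch; the import and example call are added for illustration:

import numba

@numba.njit
def clip(x, min_val, max_val):
    # Plain comparisons on Numba scalars; no to_scalar-style conversion.
    if x < min_val:
        return min_val
    elif x > max_val:
        return max_val
    else:
        return x

print(clip(5.0, 0.0, 1.0))  # -> 1.0
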
pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 3 deletions
@@ -365,7 +365,7 @@ def add_output_storage_post_proc_stmt(
         storage_alloc_stmts.append(
             dedent(
                 f"""
-                {storage_size_name} = to_numba_scalar({outer_in_name})
+                {storage_size_name} = ({outer_in_name}).item()
                 {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
                 """
             ).strip()
@@ -435,10 +435,9 @@ def scan({", ".join(outer_in_names)}):
     """
 
     global_env = {
+        "np": np,
         "scan_inner_func": scan_inner_func,
-        "to_numba_scalar": numba_basic.to_scalar,
     }
-    global_env["np"] = np
 
     scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
 
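The scan change inlines `.item()` into the generated source and seeds `global_env` with `np` up front instead of assigning it afterwards. A rough, hypothetical illustration of compiling generated source this way; the names are made up and pytensor's `compile_function_src` does more than this sketch:

import numpy as np

outer_in_name = "n_steps"
storage_size_name = "storage_size"
storage_name = "out_storage"

# Generated source that sizes a storage buffer from a 0-d array input.
src = f"""
def alloc_storage({outer_in_name}):
    {storage_size_name} = ({outer_in_name}).item()
    {storage_name} = np.empty({storage_size_name}, dtype=np.float64)
    return {storage_name}
"""

global_env = {"np": np}
exec(compile(src, "<generated>", "exec"), global_env)

print(global_env["alloc_storage"](np.array(3)).shape)  # -> (3,)
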
pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 14 additions & 15 deletions
@@ -28,18 +28,17 @@
 def numba_funcify_AllocEmpty(op, node, **kwargs):
     global_env = {
         "np": np,
-        "to_scalar": numba_basic.to_scalar,
         "dtype": np.dtype(op.dtype),
     }
 
     unique_names = unique_name_generator(
-        ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
+        ["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
     )
     shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
     shape_var_item_names = [f"{name}_item" for name in shape_var_names]
     shapes_to_items_src = indent(
         "\n".join(
-            f"{item_name} = to_scalar({shape_name})"
+            f"{item_name} = {shape_name}.item()"
             for item_name, shape_name in zip(
                 shape_var_item_names, shape_var_names, strict=True
             )
@@ -63,10 +62,10 @@ def allocempty({", ".join(shape_var_names)}):
 
 @numba_funcify.register(Alloc)
 def numba_funcify_Alloc(op, node, **kwargs):
-    global_env = {"np": np, "to_scalar": numba_basic.to_scalar}
+    global_env = {"np": np}
 
     unique_names = unique_name_generator(
-        ["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"],
+        ["np", "alloc", "val_np", "val", "scalar_shape", "res"],
         suffix_sep="_",
     )
     shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
@@ -110,9 +109,9 @@ def numba_funcify_ARange(op, **kwargs):
     @numba_basic.numba_njit(inline="always")
     def arange(start, stop, step):
         return np.arange(
-            numba_basic.to_scalar(start),
-            numba_basic.to_scalar(stop),
-            numba_basic.to_scalar(step),
+            start.item(),
+            stop.item(),
+            step.item(),
             dtype=dtype,
         )
 
@@ -187,9 +186,9 @@ def numba_funcify_Eye(op, **kwargs):
     @numba_basic.numba_njit(inline="always")
     def eye(N, M, k):
         return np.eye(
-            numba_basic.to_scalar(N),
-            numba_basic.to_scalar(M),
-            numba_basic.to_scalar(k),
+            N.item(),
+            M.item(),
+            k.item(),
             dtype=dtype,
         )
 
@@ -200,16 +199,16 @@ def eye(N, M, k):
 def numba_funcify_MakeVector(op, node, **kwargs):
     dtype = np.dtype(op.dtype)
 
-    global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype}
+    global_env = {"np": np, "dtype": dtype}
 
     unique_names = unique_name_generator(
-        ["np", "to_scalar"],
+        ["np"],
         suffix_sep="_",
     )
     input_names = [unique_names(v, force_unique=True) for v in node.inputs]
 
     def create_list_string(x):
-        args = ", ".join([f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else []))
+        args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
         return f"[{args}]"
 
     makevector_def_src = f"""
@@ -237,7 +236,7 @@ def tensor_from_scalar(x):
 def numba_funcify_ScalarFromTensor(op, **kwargs):
     @numba_basic.numba_njit(inline="always")
     def scalar_from_tensor(x):
-        return numba_basic.to_scalar(x)
+        return x.item()
 
     return scalar_from_tensor
 
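For ops such as `ARange` and `Eye`, the shape and parameter inputs arrive as 0-d arrays and are now unboxed with `.item()` at the call site. A minimal sketch of the `ARange`-style pattern; the function name and example call are illustrative:

import numba
import numpy as np

@numba.njit
def arange_from_0d(start, stop, step):
    # Each argument is a 0-d array; .item() yields the scalars that
    # np.arange expects, matching the updated call sites above.
    return np.arange(start.item(), stop.item(), step.item())

print(arange_from_0d(np.array(0), np.array(5), np.array(1)))  # -> [0 1 2 3 4]
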
tests/link/numba/test_basic.py

Lines changed: 0 additions & 7 deletions
@@ -135,12 +135,6 @@ def py_tuple_setitem(t, i, v):
         ll[i] = v
         return tuple(ll)
 
-    def py_to_scalar(x):
-        if isinstance(x, np.ndarray):
-            return x.item()
-        else:
-            return x
-
     def njit_noop(*args, **kwargs):
         if len(args) == 1 and callable(args[0]):
             return args[0]
@@ -180,7 +174,6 @@ def inner_vec(*args):
         mock.patch(
             "pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
        ),
-        mock.patch("pytensor.link.numba.dispatch.basic.to_scalar", py_to_scalar),
         mock.patch(
             "pytensor.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
             lambda dtype: dtype,