Skip to content

Commit 605d2ff

Browse files
committed
Reorder functions in numba/dispatch/basic.py
Helpers before dispatchers
1 parent 4cc47bb commit 605d2ff

File tree

1 file changed

+65
-65
lines changed

1 file changed

+65
-65
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,55 @@ def create_arg_string(x):
166166
return args
167167

168168

169+
@numba.extending.intrinsic
170+
def direct_cast(typingctx, val, typ):
171+
if isinstance(typ, numba.types.TypeRef):
172+
casted = typ.instance_type
173+
elif isinstance(typ, numba.types.DTypeSpec):
174+
casted = typ.dtype
175+
else:
176+
casted = typ
177+
178+
sig = casted(casted, typ)
179+
180+
def codegen(context, builder, signature, args):
181+
val, _ = args
182+
context.nrt.incref(builder, signature.return_type, val)
183+
return val
184+
185+
return sig, codegen
186+
187+
188+
def int_to_float_fn(inputs, out_dtype):
189+
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
190+
191+
if (
192+
all(inp.type.dtype == out_dtype for inp in inputs)
193+
and np.dtype(out_dtype).kind == "f"
194+
):
195+
196+
@numba_njit(inline="always")
197+
def inputs_cast(x):
198+
return x
199+
200+
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
201+
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
202+
203+
@numba_njit(inline="always")
204+
def inputs_cast(x):
205+
return x.astype(args_dtype)
206+
207+
else:
208+
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
209+
args_dtype = np.dtype(f"f{args_dtype_sz}")
210+
211+
@numba_njit(inline="always")
212+
def inputs_cast(x):
213+
return x.astype(args_dtype)
214+
215+
return inputs_cast
216+
217+
169218
@singledispatch
170219
def numba_typify(data, dtype=None, **kwargs):
171220
return data
@@ -231,6 +280,22 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
231280
return generate_fallback_impl(op, node, storage_map, **kwargs)
232281

233282

283+
@numba_funcify.register(FunctionGraph)
284+
def numba_funcify_FunctionGraph(
285+
fgraph,
286+
node=None,
287+
fgraph_name="numba_funcified_fgraph",
288+
**kwargs,
289+
):
290+
return fgraph_to_python(
291+
fgraph,
292+
numba_funcify,
293+
type_conversion_fn=numba_typify,
294+
fgraph_name=fgraph_name,
295+
**kwargs,
296+
)
297+
298+
234299
@numba_funcify.register(OpFromGraph)
235300
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
236301
_ = kwargs.pop("storage_map", None)
@@ -263,22 +328,6 @@ def opfromgraph(*inputs):
263328
return opfromgraph
264329

265330

266-
@numba_funcify.register(FunctionGraph)
267-
def numba_funcify_FunctionGraph(
268-
fgraph,
269-
node=None,
270-
fgraph_name="numba_funcified_fgraph",
271-
**kwargs,
272-
):
273-
return fgraph_to_python(
274-
fgraph,
275-
numba_funcify,
276-
type_conversion_fn=numba_typify,
277-
fgraph_name=fgraph_name,
278-
**kwargs,
279-
)
280-
281-
282331
@numba_funcify.register(DeepCopyOp)
283332
def numba_funcify_DeepCopyOp(op, node, **kwargs):
284333
if isinstance(node.inputs[0].type, TensorType):
@@ -296,55 +345,6 @@ def deepcopy(x):
296345
return deepcopy
297346

298347

299-
@numba.extending.intrinsic
300-
def direct_cast(typingctx, val, typ):
301-
if isinstance(typ, numba.types.TypeRef):
302-
casted = typ.instance_type
303-
elif isinstance(typ, numba.types.DTypeSpec):
304-
casted = typ.dtype
305-
else:
306-
casted = typ
307-
308-
sig = casted(casted, typ)
309-
310-
def codegen(context, builder, signature, args):
311-
val, _ = args
312-
context.nrt.incref(builder, signature.return_type, val)
313-
return val
314-
315-
return sig, codegen
316-
317-
318-
def int_to_float_fn(inputs, out_dtype):
319-
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
320-
321-
if (
322-
all(inp.type.dtype == out_dtype for inp in inputs)
323-
and np.dtype(out_dtype).kind == "f"
324-
):
325-
326-
@numba_njit(inline="always")
327-
def inputs_cast(x):
328-
return x
329-
330-
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
331-
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
332-
333-
@numba_njit(inline="always")
334-
def inputs_cast(x):
335-
return x.astype(args_dtype)
336-
337-
else:
338-
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
339-
args_dtype = np.dtype(f"f{args_dtype_sz}")
340-
341-
@numba_njit(inline="always")
342-
def inputs_cast(x):
343-
return x.astype(args_dtype)
344-
345-
return inputs_cast
346-
347-
348348
@numba_funcify.register(IfElse)
349349
def numba_funcify_IfElse(op, **kwargs):
350350
n_outs = op.n_outs

0 commit comments

Comments
 (0)