@@ -166,6 +166,55 @@ def create_arg_string(x):
166166 return args
167167
168168
169+ @numba .extending .intrinsic
170+ def direct_cast (typingctx , val , typ ):
171+ if isinstance (typ , numba .types .TypeRef ):
172+ casted = typ .instance_type
173+ elif isinstance (typ , numba .types .DTypeSpec ):
174+ casted = typ .dtype
175+ else :
176+ casted = typ
177+
178+ sig = casted (casted , typ )
179+
180+ def codegen (context , builder , signature , args ):
181+ val , _ = args
182+ context .nrt .incref (builder , signature .return_type , val )
183+ return val
184+
185+ return sig , codegen
186+
187+
188+ def int_to_float_fn (inputs , out_dtype ):
189+ """Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
190+
191+ if (
192+ all (inp .type .dtype == out_dtype for inp in inputs )
193+ and np .dtype (out_dtype ).kind == "f"
194+ ):
195+
196+ @numba_njit (inline = "always" )
197+ def inputs_cast (x ):
198+ return x
199+
200+ elif any (i .type .numpy_dtype .kind in "uib" for i in inputs ):
201+ args_dtype = np .dtype (f"f{ out_dtype .itemsize } " )
202+
203+ @numba_njit (inline = "always" )
204+ def inputs_cast (x ):
205+ return x .astype (args_dtype )
206+
207+ else :
208+ args_dtype_sz = max (_arg .type .numpy_dtype .itemsize for _arg in inputs )
209+ args_dtype = np .dtype (f"f{ args_dtype_sz } " )
210+
211+ @numba_njit (inline = "always" )
212+ def inputs_cast (x ):
213+ return x .astype (args_dtype )
214+
215+ return inputs_cast
216+
217+
169218@singledispatch
170219def numba_typify (data , dtype = None , ** kwargs ):
171220 return data
@@ -231,6 +280,22 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
231280 return generate_fallback_impl (op , node , storage_map , ** kwargs )
232281
233282
283+ @numba_funcify .register (FunctionGraph )
284+ def numba_funcify_FunctionGraph (
285+ fgraph ,
286+ node = None ,
287+ fgraph_name = "numba_funcified_fgraph" ,
288+ ** kwargs ,
289+ ):
290+ return fgraph_to_python (
291+ fgraph ,
292+ numba_funcify ,
293+ type_conversion_fn = numba_typify ,
294+ fgraph_name = fgraph_name ,
295+ ** kwargs ,
296+ )
297+
298+
234299@numba_funcify .register (OpFromGraph )
235300def numba_funcify_OpFromGraph (op , node = None , ** kwargs ):
236301 _ = kwargs .pop ("storage_map" , None )
@@ -263,22 +328,6 @@ def opfromgraph(*inputs):
263328 return opfromgraph
264329
265330
266- @numba_funcify .register (FunctionGraph )
267- def numba_funcify_FunctionGraph (
268- fgraph ,
269- node = None ,
270- fgraph_name = "numba_funcified_fgraph" ,
271- ** kwargs ,
272- ):
273- return fgraph_to_python (
274- fgraph ,
275- numba_funcify ,
276- type_conversion_fn = numba_typify ,
277- fgraph_name = fgraph_name ,
278- ** kwargs ,
279- )
280-
281-
282331@numba_funcify .register (DeepCopyOp )
283332def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
284333 if isinstance (node .inputs [0 ].type , TensorType ):
@@ -296,55 +345,6 @@ def deepcopy(x):
296345 return deepcopy
297346
298347
299- @numba .extending .intrinsic
300- def direct_cast (typingctx , val , typ ):
301- if isinstance (typ , numba .types .TypeRef ):
302- casted = typ .instance_type
303- elif isinstance (typ , numba .types .DTypeSpec ):
304- casted = typ .dtype
305- else :
306- casted = typ
307-
308- sig = casted (casted , typ )
309-
310- def codegen (context , builder , signature , args ):
311- val , _ = args
312- context .nrt .incref (builder , signature .return_type , val )
313- return val
314-
315- return sig , codegen
316-
317-
318- def int_to_float_fn (inputs , out_dtype ):
319- """Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
320-
321- if (
322- all (inp .type .dtype == out_dtype for inp in inputs )
323- and np .dtype (out_dtype ).kind == "f"
324- ):
325-
326- @numba_njit (inline = "always" )
327- def inputs_cast (x ):
328- return x
329-
330- elif any (i .type .numpy_dtype .kind in "uib" for i in inputs ):
331- args_dtype = np .dtype (f"f{ out_dtype .itemsize } " )
332-
333- @numba_njit (inline = "always" )
334- def inputs_cast (x ):
335- return x .astype (args_dtype )
336-
337- else :
338- args_dtype_sz = max (_arg .type .numpy_dtype .itemsize for _arg in inputs )
339- args_dtype = np .dtype (f"f{ args_dtype_sz } " )
340-
341- @numba_njit (inline = "always" )
342- def inputs_cast (x ):
343- return x .astype (args_dtype )
344-
345- return inputs_cast
346-
347-
348348@numba_funcify .register (IfElse )
349349def numba_funcify_IfElse (op , ** kwargs ):
350350 n_outs = op .n_outs
0 commit comments