@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
141141 )(scalar_op_fn )
142142
143143
144- @numba_basic .numba_njit
145- def switch (condition , x , y ):
146- if condition :
147- return x
148- else :
149- return y
150-
151-
152144@numba_funcify .register (Switch )
153145def numba_funcify_Switch (op , node , ** kwargs ):
154- return numba_basic .global_numba_func (switch )
146+ @numba_basic .numba_njit
147+ def switch (condition , x , y ):
148+ if condition :
149+ return x
150+ else :
151+ return y
152+
153+ return switch
155154
156155
157156def binary_to_nary_func (inputs : list [Variable ], binary_op_name : str , binary_op : str ):
@@ -197,34 +196,32 @@ def cast(x):
197196 return cast
198197
199198
200- @numba_basic .numba_njit
201- def identity (x ):
202- return x
203-
204-
205199@numba_funcify .register (Identity )
206200@numba_funcify .register (TypeCastingOp )
207201def numba_funcify_type_casting (op , ** kwargs ):
208- return numba_basic .global_numba_func (identity )
209-
210-
211- @numba_basic .numba_njit
212- def clip (_x , _min , _max ):
213- x = numba_basic .to_scalar (_x )
214- _min_scalar = numba_basic .to_scalar (_min )
215- _max_scalar = numba_basic .to_scalar (_max )
216-
217- if x < _min_scalar :
218- return _min_scalar
219- elif x > _max_scalar :
220- return _max_scalar
221- else :
202+ @numba_basic .numba_njit
203+ def identity (x ):
222204 return x
223205
206+ return identity
207+
224208
225209@numba_funcify .register (Clip )
226210def numba_funcify_Clip (op , ** kwargs ):
227- return numba_basic .global_numba_func (clip )
211+ @numba_basic .numba_njit
212+ def clip (x , min_val , max_val ):
213+ x = numba_basic .to_scalar (x )
214+ min_scalar = numba_basic .to_scalar (min_val )
215+ max_scalar = numba_basic .to_scalar (max_val )
216+
217+ if x < min_scalar :
218+ return min_scalar
219+ elif x > max_scalar :
220+ return max_scalar
221+ else :
222+ return x
223+
224+ return clip
228225
229226
230227@numba_funcify .register (Composite )
@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
239236 return composite_fn
240237
241238
242- @numba_basic .numba_njit
243- def second (x , y ):
244- return y
245-
246-
247239@numba_funcify .register (Second )
248240def numba_funcify_Second (op , node , ** kwargs ):
249- return numba_basic .global_numba_func (second )
250-
241+ @numba_basic .numba_njit
242+ def second (x , y ):
243+ return y
251244
252- @numba_basic .numba_njit
253- def reciprocal (x ):
254- # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
255- # `x` is an `int`
256- return 1 / x
245+ return second
257246
258247
259248@numba_funcify .register (Reciprocal )
260249def numba_funcify_Reciprocal (op , node , ** kwargs ):
261- return numba_basic .global_numba_func (reciprocal )
262-
250+ @numba_basic .numba_njit
251+ def reciprocal (x ):
252+ # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
253+ # `x` is an `int`
254+ return 1 / x
263255
264- @numba_basic .numba_njit
265- def sigmoid (x ):
266- return 1 / (1 + np .exp (- x ))
256+ return reciprocal
267257
268258
269259@numba_funcify .register (Sigmoid )
270260def numba_funcify_Sigmoid (op , node , ** kwargs ):
271- return numba_basic .global_numba_func (sigmoid )
272-
261+ @numba_basic .numba_njit
262+ def sigmoid (x ):
263+ return 1 / (1 + np .exp (- x ))
273264
274- @numba_basic .numba_njit
275- def gammaln (x ):
276- return math .lgamma (x )
265+ return sigmoid
277266
278267
279268@numba_funcify .register (GammaLn )
280269def numba_funcify_GammaLn (op , node , ** kwargs ):
281- return numba_basic .global_numba_func (gammaln )
282-
270+ @numba_basic .numba_njit
271+ def gammaln (x ):
272+ return math .lgamma (x )
283273
284- @numba_basic .numba_njit
285- def logp1mexp (x ):
286- if x < np .log (0.5 ):
287- return np .log1p (- np .exp (x ))
288- else :
289- return np .log (- np .expm1 (x ))
274+ return gammaln
290275
291276
292277@numba_funcify .register (Log1mexp )
293278def numba_funcify_Log1mexp (op , node , ** kwargs ):
294- return numba_basic .global_numba_func (logp1mexp )
295-
279+ @numba_basic .numba_njit
280+ def logp1mexp (x ):
281+ if x < np .log (0.5 ):
282+ return np .log1p (- np .exp (x ))
283+ else :
284+ return np .log (- np .expm1 (x ))
296285
297- @numba_basic .numba_njit
298- def erf (x ):
299- return math .erf (x )
286+ return logp1mexp
300287
301288
302289@numba_funcify .register (Erf )
303290def numba_funcify_Erf (op , ** kwargs ):
304- return numba_basic .global_numba_func (erf )
305-
291+ @numba_basic .numba_njit
292+ def erf (x ):
293+ return math .erf (x )
306294
307- @numba_basic .numba_njit
308- def erfc (x ):
309- return math .erfc (x )
295+ return erf
310296
311297
312298@numba_funcify .register (Erfc )
313299def numba_funcify_Erfc (op , ** kwargs ):
314- return numba_basic .global_numba_func (erfc )
300+ @numba_basic .numba_njit
301+ def erfc (x ):
302+ return math .erfc (x )
303+
304+ return erfc
315305
316306
317307@numba_funcify .register (Softplus )
0 commit comments