@@ -681,6 +681,22 @@ def largest(*args):
681681 return max (stack (args ), axis = 0 )
682682
683683
684+ def isposinf (x ):
685+ """
686+ Return if the input variable has positive infinity element
687+
688+ """
689+ return eq (x , np .inf )
690+
691+
692+ def isneginf (x ):
693+ """
694+ Return if the input variable has negative infinity element
695+
696+ """
697+ return eq (x , - np .inf )
698+
699+
684700@scalar_elemwise
685701def lt (a , b ):
686702 """a < b"""
@@ -2913,6 +2929,62 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
29132929 return vectorize_node_fallback (op , node , batched_x , batched_y )
29142930
29152931
2932+ def nan_to_num (x , nan = 0.0 , posinf = None , neginf = None ):
2933+ """
2934+ Replace NaN with zero and infinity with large finite numbers (default
2935+ behaviour) or with the numbers defined by the user using the `nan`,
2936+ `posinf` and/or `neginf` keywords.
2937+
2938+ NaN is replaced by zero or by the user defined value in
2939+ `nan` keyword, infinity is replaced by the largest finite floating point
2940+ values representable by ``x.dtype`` or by the user defined value in
2941+ `posinf` keyword and -infinity is replaced by the most negative finite
2942+ floating point values representable by ``x.dtype`` or by the user defined
2943+ value in `neginf` keyword.
2944+
2945+ Parameters
2946+ ----------
2947+ x : symbolic tensor
2948+ Input array.
2949+ nan
2950+ The value to replace NaN's with in the tensor (default = 0).
2951+ posinf
2952+ The value to replace +INF with in the tensor (default max
2953+ in range representable by ``x.dtype``).
2954+ neginf
2955+ The value to replace -INF with in the tensor (default min
2956+ in range representable by ``x.dtype``).
2957+
2958+ Returns
2959+ -------
2960+ out
2961+ The tensor with NaN's, +INF, and -INF replaced with the
2962+ specified and/or default substitutions.
2963+ """
2964+ # Replace NaN's with nan keyword
2965+ is_nan = isnan (x )
2966+ is_pos_inf = isposinf (x )
2967+ is_neg_inf = isneginf (x )
2968+
2969+ x = switch (is_nan , nan , x )
2970+
2971+ # Get max and min values representable by x.dtype
2972+ maxf = posinf
2973+ minf = neginf
2974+
2975+ # Specify the value to replace +INF and -INF with
2976+ if maxf is None :
2977+ maxf = np .finfo (x .real .dtype ).max
2978+ if minf is None :
2979+ minf = np .finfo (x .real .dtype ).min
2980+
2981+ # Replace +INF and -INF values
2982+ x = switch (is_pos_inf , maxf , x )
2983+ x = switch (is_neg_inf , minf , x )
2984+
2985+ return x
2986+
2987+
29162988# NumPy logical aliases
29172989square = sqr
29182990
@@ -2951,6 +3023,8 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
29513023 "not_equal" ,
29523024 "isnan" ,
29533025 "isinf" ,
3026+ "isposinf" ,
3027+ "isneginf" ,
29543028 "allclose" ,
29553029 "isclose" ,
29563030 "and_" ,
@@ -3069,4 +3143,5 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
30693143 "logaddexp" ,
30703144 "logsumexp" ,
30713145 "hyp2f1" ,
3146+ "nan_to_num" ,
30723147]
0 commit comments