@@ -3488,6 +3488,17 @@ def func(x):
34883488 return tf .identity (picks , name = _TFOUTPUT )
34893489 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
34903490
3491+ @check_opset_min_version (9 , "IsNaN" )
3492+ def test_where_ismulinf (self ):
3493+ x_val1 = np .array ([np .inf ], dtype = np .float32 )
3494+ x_val2 = np .array ([0 ], dtype = np .float32 )
3495+ true_result = np .array ([np .inf ], dtype = np .float32 )
3496+ def func (x1 , x2 ):
3497+ mul = tf .multiply (x1 , x2 )
3498+ picks = tf .where (x1 < mul , true_result , x2 )
3499+ return tf .identity (picks , name = _TFOUTPUT )
3500+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val1 , _INPUT1 : x_val2 })
3501+
34913502 @check_opset_min_version (9 , "Where for strings needs opset 9" )
34923503 @skip_tfjs ("Technically tf where doesn't support strings and tfjs doesn't like it" )
34933504 def test_where_string (self ):
@@ -5542,7 +5553,7 @@ def func(x):
55425553 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val0 }, rtol = 1e-6 , atol = 1e-4 )
55435554 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
55445555
5545- x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024
5556+ x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024.
55465557 x_val [0 , 0 ] = - 1024
55475558 x_val [0 , 1 ] = - 1023
55485559 x_val [0 , 2 ] = 1024
@@ -5579,7 +5590,7 @@ def func(x):
55795590 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val0 }, rtol = 1e-6 , atol = 1e-4 )
55805591 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
55815592
5582- x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024
5593+ x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024.
55835594 x_val [0 , 0 ] = - 1024
55845595 x_val [0 , 1 ] = - 1023
55855596 x_val [0 , 2 ] = 1024
0 commit comments