@@ -970,80 +970,65 @@ def test_clip(x, data):
970970 expected_shape = sh .broadcast_shapes (* shapes )
971971 ph .assert_shape ("clip" , out_shape = out .shape , expected = expected_shape )
972972
973- if min is max is None :
974- ph .assert_array_elements ("clip" , out = out , expected = x )
975- elif max is None :
976- # If one operand is nan, the result is nan. See
977- # https://github.com/data-apis/array-api/pull/813.
978- def refimpl (_x , _min ):
979- if math .isnan (_x ) or math .isnan (_min ):
980- return math .nan
973+ # This is based on right_scalar_assert_against_refimpl and
974+ # binary_assert_against_refimpl. clip() is currently the only ternary
975+ # elementwise function and the only function that supports arrays and
976+ # scalars. However, where() (in test_searching_functions) is similar
977+ # and if scalar support is added to it, we may want to factor out and
978+ # reuse this logic.
979+
980+ def refimpl (_x , _min , _max ):
981+ # Skip cases where _min and _max are integers whose values do not
982+ # fit in the dtype of _x, since this behavior is unspecified.
983+ if dh .is_int_dtype (x .dtype ):
984+ if _min is not None and _min not in dh .dtype_ranges [x .dtype ]:
985+ return None
986+ if _max is not None and _max not in dh .dtype_ranges [x .dtype ]:
987+ return None
988+
989+ if (math .isnan (_x )
990+ or (_min is not None and math .isnan (_min ))
991+ or (_max is not None and math .isnan (_max ))):
992+ return math .nan
993+ if _min is _max is None :
994+ return _x
995+ if _max is None :
981996 return builtins .max (_x , _min )
982- if dh .is_scalar (min ):
983- right_scalar_assert_against_refimpl (
984- "clip" , x , min , out , refimpl ,
985- left_sym = "x" ,
986- expr_template = "clip({}, min={})" ,
987- )
988- else :
989- binary_assert_against_refimpl (
990- "clip" , x , min , out , refimpl ,
991- left_sym = "x" , right_sym = "min" ,
992- expr_template = "clip({}, min={})" ,
993- )
994- elif min is None :
995- def refimpl (_x , _max ):
996- if math .isnan (_x ) or math .isnan (_max ):
997- return math .nan
997+ if _min is None :
998998 return builtins .min (_x , _max )
999- if dh .is_scalar (max ):
1000- right_scalar_assert_against_refimpl (
1001- "clip" , x , max , out , refimpl ,
1002- left_sym = "x" ,
1003- expr_template = "clip({}, max={})" ,
999+ return builtins .min (builtins .max (_x , _min ), _max )
1000+
1001+ stype = dh .get_scalar_type (x .dtype )
1002+ min_shape = () if min is None or dh .is_scalar (min ) else min .shape
1003+ max_shape = () if max is None or dh .is_scalar (max ) else max .shape
1004+
1005+ for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1006+ x .shape , min_shape , max_shape , out .shape ):
1007+ x_val = stype (x [x_idx ])
1008+ if min is None or dh .is_scalar (min ):
1009+ min_val = min
1010+ else :
1011+ min_val = stype (min [min_idx ])
1012+ if max is None or dh .is_scalar (max ):
1013+ max_val = max
1014+ else :
1015+ max_val = stype (max [max_idx ])
1016+ expected = refimpl (x_val , min_val , max_val )
1017+ if expected is None :
1018+ continue
1019+ out_val = stype (out [o_idx ])
1020+ if math .isnan (expected ):
1021+ assert math .isnan (out_val ), (
1022+ f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1023+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
10041024 )
10051025 else :
1006- binary_assert_against_refimpl (
1007- "clip" , x , max , out , refimpl ,
1008- left_sym = "x" , right_sym = "max" ,
1009- expr_template = "clip({}, max={})" ,
1026+ assert out_val == expected , (
1027+ f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1028+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
10101029 )
1011- else :
1012- def refimpl (_x , _min , _max ):
1013- if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
1014- return math .nan
1015- return builtins .min (builtins .max (_x , _min ), _max )
1016-
1017- # This is based on right_scalar_assert_against_refimpl and
1018- # binary_assert_against_refimpl. clip() is currently the only ternary
1019- # elementwise function and the only function that supports arrays and
1020- # scalars. However, where() (in test_searching_functions) is similar
1021- # and if scalar support is added to it, we may want to factor out and
1022- # reuse this logic.
1023-
1024- stype = dh .get_scalar_type (x .dtype )
1025- min_shape = () if dh .is_scalar (min ) else min .shape
1026- max_shape = () if dh .is_scalar (max ) else max .shape
1027-
1028- for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1029- x .shape , min_shape , max_shape , out .shape ):
1030- x_val = stype (x [x_idx ])
1031- min_val = min if dh .is_scalar (min ) else min [min_idx ]
1032- min_val = stype (min_val )
1033- max_val = max if dh .is_scalar (max ) else max [max_idx ]
1034- max_val = stype (max_val )
1035- expected = refimpl (x_val , min_val , max_val )
1036- out_val = stype (out [o_idx ])
1037- if math .isnan (expected ):
1038- assert math .isnan (out_val ), (
1039- f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1040- f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1041- )
1042- else :
1043- assert out_val == expected , (
1044- f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1045- f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1046- )
1030+
1031+
10471032if api_version >= "2022.12" :
10481033
10491034 @given (hh .arrays (dtype = hh .complex_dtypes , shape = hh .shapes ()))
0 commit comments