@@ -947,13 +947,76 @@ def test_clip(x, data):
947947 hh .arrays (dtype = dtypes , shape = shape2 ),
948948 ))
949949
950+ # min > max is undefined (but allow nans)
951+ assume (min is None or max is None or not xp .any (xp .asarray (min > max )))
952+
950953 kw = data .draw (hh .specified_kwargs (("min" , min , None ), ("max" , max , None )))
951954
952955 out = xp .clip (x , ** kw )
956+
957+ # min and max do not participate in type promotion
953958 ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
954- ph .assert_shape ("clip" , out_shape = out .shape , expected = x .shape )
955- ph .assert_array_elements ("clip" , out = out , expected = x )
956959
960+ shapes = [x .shape ]
961+ if min is not None and not dh .is_scalar (min ):
962+ shapes .append (min .shape )
963+ if max is not None and not dh .is_scalar (max ):
964+ shapes .append (max .shape )
965+ expected_shape = sh .broadcast_shapes (* shapes )
966+ ph .assert_shape ("clip" , out_shape = out .shape , expected = expected_shape )
967+
968+ if min is max is None :
969+ ph .assert_array_elements ("clip" , out = out , expected = x )
970+ elif min is not None :
971+ # If one operand is nan, the result is nan. See
972+ # https://github.com/data-apis/array-api/pull/813.
973+ def refimpl (_x , _min ):
974+ if math .isnan (_x ) or math .isnan (_min ):
975+ return math .nan
976+ return max (_x , _min )
977+ if dh .is_scalar (min ):
978+ right_scalar_assert_against_refimpl (
979+ "clip" , x , min , out , refimpl ,
980+ left_sym = "x" ,
981+ expr_template = "clip({}, min={})" ,
982+ )
983+ else :
984+ binary_assert_against_refimpl (
985+ "clip" , x , min , out , refimpl ,
986+ left_sym = "x" , right_sym = "min" ,
987+ expr_template = "clip({}, min={})" ,
988+ )
989+ elif max is not None :
990+ def refimpl (_x , _max ):
991+ if math .isnan (_x ) or math .isnan (_max ):
992+ return math .nan
993+ return min (_x , _max )
994+ if dh .is_scalar (max ):
995+ right_scalar_assert_against_refimpl (
996+ "clip" , x , max , out , refimpl ,
997+ left_sym = "x" ,
998+ expr_template = "clip({}, max={})" ,
999+ )
1000+ else :
1001+ binary_assert_against_refimpl (
1002+ "clip" , x , max , out , refimpl ,
1003+ left_sym = "x" , right_sym = "max" ,
1004+ expr_template = "clip({}, max={})" ,
1005+ )
1006+ else :
1007+ def refimpl (_x , _min , _max ):
1008+ if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
1009+ return math .nan
1010+ return min (max (_x , _min ), _max )
1011+
1012+ # This is based on right_scalar_assert_against_refimpl and
1013+ # binary_assert_against_refimpl. clip() is currently the only ternary
1014+ # elementwise function and the only function that supports arrays and
1015+ # scalars. However, where() (in test_searching_functions) is similar
1016+ # and if scalar support is added to it, we may want to factor out and
1017+ # reuse this logic.
1018+
1019+ # TODO
9571020
9581021if api_version >= "2022.12" :
9591022
0 commit comments