44import cmath
55import math
66import operator
7+ import builtins
78from copy import copy
89from enum import Enum , auto
910from typing import Callable , List , NamedTuple , Optional , Sequence , TypeVar , Union
@@ -932,25 +933,31 @@ def test_clip(x, data):
932933
933934 # Ensure that if both min and max are arrays that all three of x, min, max
934935 # are broadcast compatible.
935- shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 , base_shape = x .shape ))
936+ shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 ,
937+ base_shape = x .shape ),
938+ label = "min.shape, max.shape" )
936939
937940 dtypes = hh .real_floating_dtypes if dh .is_float_dtype (x .dtype ) else hh .int_dtypes
938941
939942 min = data .draw (st .one_of (
940943 st .none (),
941944 hh .scalars (dtypes = st .just (x .dtype )),
942945 hh .arrays (dtype = dtypes , shape = shape1 ),
943- ))
946+ ), label = "min" )
944947 max = data .draw (st .one_of (
945948 st .none (),
946949 hh .scalars (dtypes = st .just (x .dtype )),
947950 hh .arrays (dtype = dtypes , shape = shape2 ),
948- ))
951+ ), label = "max" )
949952
950953 # min > max is undefined (but allow nans)
951954 assume (min is None or max is None or not xp .any (xp .asarray (min > max )))
952955
953- kw = data .draw (hh .specified_kwargs (("min" , min , None ), ("max" , max , None )))
956+ kw = data .draw (
957+ hh .specified_kwargs (
958+ ("min" , min , None ),
959+ ("max" , max , None )),
960+ label = "kwargs" )
954961
955962 out = xp .clip (x , ** kw )
956963
@@ -967,13 +974,13 @@ def test_clip(x, data):
967974
968975 if min is max is None :
969976 ph .assert_array_elements ("clip" , out = out , expected = x )
970- elif min is not None :
977+ elif max is None :
971978 # If one operand is nan, the result is nan. See
972979 # https://github.com/data-apis/array-api/pull/813.
973980 def refimpl (_x , _min ):
974981 if math .isnan (_x ) or math .isnan (_min ):
975982 return math .nan
976- return max (_x , _min )
983+ return builtins . max (_x , _min )
977984 if dh .is_scalar (min ):
978985 right_scalar_assert_against_refimpl (
979986 "clip" , x , min , out , refimpl ,
@@ -986,11 +993,11 @@ def refimpl(_x, _min):
986993 left_sym = "x" , right_sym = "min" ,
987994 expr_template = "clip({}, min={})" ,
988995 )
989- elif max is not None :
996+ elif min is None :
990997 def refimpl (_x , _max ):
991998 if math .isnan (_x ) or math .isnan (_max ):
992999 return math .nan
993- return min (_x , _max )
1000+ return builtins . min (_x , _max )
9941001 if dh .is_scalar (max ):
9951002 right_scalar_assert_against_refimpl (
9961003 "clip" , x , max , out , refimpl ,
@@ -1007,7 +1014,7 @@ def refimpl(_x, _max):
10071014 def refimpl (_x , _min , _max ):
10081015 if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
10091016 return math .nan
1010- return min (max (_x , _min ), _max )
1017+ return builtins . min (builtins . max (_x , _min ), _max )
10111018
10121019 # This is based on right_scalar_assert_against_refimpl and
10131020 # binary_assert_against_refimpl. clip() is currently the only ternary
@@ -1016,8 +1023,28 @@ def refimpl(_x, _min, _max):
10161023 # and if scalar support is added to it, we may want to factor out and
10171024 # reuse this logic.
10181025
1019- # TODO
1020-
1026+ stype = dh .get_scalar_type (x .dtype )
1027+ min_shape = () if dh .is_scalar (min ) else min .shape
1028+ max_shape = () if dh .is_scalar (max ) else max .shape
1029+
1030+ for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1031+ x .shape , min_shape , max_shape , out .shape ):
1032+ x_val = stype (x [x_idx ])
1033+ min_val = min if dh .is_scalar (min ) else min [min_idx ]
1034+ min_val = stype (min_val )
1035+ max_val = max if dh .is_scalar (max ) else max [max_idx ]
1036+ max_val = stype (max_val )
1037+ expected = refimpl (x_val , min_val , max_val )
1038+ if math .isnan (expected ):
1039+ assert math .isnan (out [o_idx ]), (
1040+ f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1041+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1042+ )
1043+ else :
1044+ assert out [o_idx ] == expected , (
1045+ f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1046+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1047+ )
10211048if api_version >= "2022.12" :
10221049
10231050 @given (hh .arrays (dtype = hh .complex_dtypes , shape = hh .shapes ()))
0 commit comments