44import cmath
55import math
66import operator
7+ import builtins
78from copy import copy
89from enum import Enum , auto
910from typing import Callable , List , NamedTuple , Optional , Sequence , TypeVar , Union
@@ -369,6 +370,8 @@ def right_scalar_assert_against_refimpl(
369370
370371 See unary_assert_against_refimpl for more information.
371372 """
373+ if expr_template is None :
374+ expr_template = func_name + "({}, {})={}"
372375 if left .dtype in dh .complex_dtypes :
373376 component_filter = copy (filter_ )
374377 filter_ = lambda s : component_filter (s .real ) and component_filter (s .imag )
@@ -422,7 +425,7 @@ def right_scalar_assert_against_refimpl(
422425 )
423426
424427
425- # When appropiate , this module tests operators alongside their respective
428+ # When appropriate , this module tests operators alongside their respective
426429# elementwise methods. We do this by parametrizing a generalised test method
427430# with every relevant method and operator.
428431#
@@ -432,8 +435,8 @@ def right_scalar_assert_against_refimpl(
432435# - The argument strategies, which can be used to draw arguments for the test
433436# case. They may require additional filtering for certain test cases.
434437# - right_is_scalar (binary parameters only), which denotes if the right
435- # argument is a scalar in a test case. This can be used to appropiately adjust
436- # draw filtering and test logic.
438+ # argument is a scalar in a test case. This can be used to appropriately
439+ # adjust draw filtering and test logic.
437440
438441
439442func_to_op = {v : k for k , v in dh .op_to_func .items ()}
@@ -475,7 +478,7 @@ def make_unary_params(
475478 )
476479 if api_version < min_version :
477480 marks = pytest .mark .skip (
478- reason = f"requires ARRAY_API_TESTS_VERSION=> { min_version } "
481+ reason = f"requires ARRAY_API_TESTS_VERSION >= { min_version } "
479482 )
480483 else :
481484 marks = ()
@@ -924,15 +927,125 @@ def test_ceil(x):
924927
925928
926929@pytest .mark .min_version ("2023.12" )
927- @given (hh .arrays (dtype = hh .real_floating_dtypes , shape = hh .shapes ()))
928- def test_clip (x ):
930+ @given (x = hh .arrays (dtype = hh .real_dtypes , shape = hh .shapes ()), data = st . data ( ))
931+ def test_clip (x , data ):
929932 # TODO: test min/max kwargs, adjust values testing accordingly
930- out = xp .clip (x )
931- ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
932- ph .assert_shape ("clip" , out_shape = out .shape , expected = x .shape )
933- ph .assert_array_elements ("clip" , out = out , expected = x )
934933
934+ # Ensure that if both min and max are arrays that all three of x, min, max
935+ # are broadcast compatible.
936+ shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 ,
937+ base_shape = x .shape ),
938+ label = "min.shape, max.shape" )
939+
940+ dtypes = hh .real_floating_dtypes if dh .is_float_dtype (x .dtype ) else hh .int_dtypes
941+
942+ min = data .draw (st .one_of (
943+ st .none (),
944+ hh .scalars (dtypes = st .just (x .dtype )),
945+ hh .arrays (dtype = dtypes , shape = shape1 ),
946+ ), label = "min" )
947+ max = data .draw (st .one_of (
948+ st .none (),
949+ hh .scalars (dtypes = st .just (x .dtype )),
950+ hh .arrays (dtype = dtypes , shape = shape2 ),
951+ ), label = "max" )
952+
953+ # min > max is undefined (but allow nans)
954+ assume (min is None or max is None or not xp .any (xp .asarray (min ) > xp .asarray (max )))
955+
956+ kw = data .draw (
957+ hh .specified_kwargs (
958+ ("min" , min , None ),
959+ ("max" , max , None )),
960+ label = "kwargs" )
961+
962+ out = xp .clip (x , ** kw )
963+
964+ # min and max do not participate in type promotion
965+ ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
935966
967+ shapes = [x .shape ]
968+ if min is not None and not dh .is_scalar (min ):
969+ shapes .append (min .shape )
970+ if max is not None and not dh .is_scalar (max ):
971+ shapes .append (max .shape )
972+ expected_shape = sh .broadcast_shapes (* shapes )
973+ ph .assert_shape ("clip" , out_shape = out .shape , expected = expected_shape )
974+
975+ if min is max is None :
976+ ph .assert_array_elements ("clip" , out = out , expected = x )
977+ elif max is None :
978+ # If one operand is nan, the result is nan. See
979+ # https://github.com/data-apis/array-api/pull/813.
980+ def refimpl (_x , _min ):
981+ if math .isnan (_x ) or math .isnan (_min ):
982+ return math .nan
983+ return builtins .max (_x , _min )
984+ if dh .is_scalar (min ):
985+ right_scalar_assert_against_refimpl (
986+ "clip" , x , min , out , refimpl ,
987+ left_sym = "x" ,
988+ expr_template = "clip({}, min={})" ,
989+ )
990+ else :
991+ binary_assert_against_refimpl (
992+ "clip" , x , min , out , refimpl ,
993+ left_sym = "x" , right_sym = "min" ,
994+ expr_template = "clip({}, min={})" ,
995+ )
996+ elif min is None :
997+ def refimpl (_x , _max ):
998+ if math .isnan (_x ) or math .isnan (_max ):
999+ return math .nan
1000+ return builtins .min (_x , _max )
1001+ if dh .is_scalar (max ):
1002+ right_scalar_assert_against_refimpl (
1003+ "clip" , x , max , out , refimpl ,
1004+ left_sym = "x" ,
1005+ expr_template = "clip({}, max={})" ,
1006+ )
1007+ else :
1008+ binary_assert_against_refimpl (
1009+ "clip" , x , max , out , refimpl ,
1010+ left_sym = "x" , right_sym = "max" ,
1011+ expr_template = "clip({}, max={})" ,
1012+ )
1013+ else :
1014+ def refimpl (_x , _min , _max ):
1015+ if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
1016+ return math .nan
1017+ return builtins .min (builtins .max (_x , _min ), _max )
1018+
1019+ # This is based on right_scalar_assert_against_refimpl and
1020+ # binary_assert_against_refimpl. clip() is currently the only ternary
1021+ # elementwise function and the only function that supports arrays and
1022+ # scalars. However, where() (in test_searching_functions) is similar
1023+ # and if scalar support is added to it, we may want to factor out and
1024+ # reuse this logic.
1025+
1026+ stype = dh .get_scalar_type (x .dtype )
1027+ min_shape = () if dh .is_scalar (min ) else min .shape
1028+ max_shape = () if dh .is_scalar (max ) else max .shape
1029+
1030+ for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1031+ x .shape , min_shape , max_shape , out .shape ):
1032+ x_val = stype (x [x_idx ])
1033+ min_val = min if dh .is_scalar (min ) else min [min_idx ]
1034+ min_val = stype (min_val )
1035+ max_val = max if dh .is_scalar (max ) else max [max_idx ]
1036+ max_val = stype (max_val )
1037+ expected = refimpl (x_val , min_val , max_val )
1038+ out_val = stype (out [o_idx ])
1039+ if math .isnan (expected ):
1040+ assert math .isnan (out_val ), (
1041+ f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1042+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1043+ )
1044+ else :
1045+ assert out_val == expected , (
1046+ f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1047+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1048+ )
9361049if api_version >= "2022.12" :
9371050
9381051 @given (hh .arrays (dtype = hh .complex_dtypes , shape = hh .shapes ()))
0 commit comments