@@ -924,10 +924,30 @@ def test_ceil(x):
924924
925925
926926@pytest .mark .min_version ("2023.12" )
927- @given (hh .arrays (dtype = hh .real_floating_dtypes , shape = hh .shapes ()))
928- def test_clip (x ):
927+ @given (x = hh .arrays (dtype = hh .real_floating_dtypes , shape = hh .shapes ()), data = st . data ( ))
928+ def test_clip (x , data ):
929929 # TODO: test min/max kwargs, adjust values testing accordingly
930- out = xp .clip (x )
930+
931+ # Ensure that if both min and max are arrays that all three of x, min, max
932+ # are broadcast compatible.
933+ shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 , base_shape = x .shape ))
934+
935+ dtypes = hh .real_floating_dtypes if dh .is_float_dtype (x .dtype ) else hh .int_dtypes
936+
937+ min = data .draw (st .one_of (
938+ st .none (),
939+ hh .scalars (dtypes = st .just (x .dtype )),
940+ hh .arrays (dtype = dtypes , shape = shape1 ),
941+ ))
942+ max = data .draw (st .one_of (
943+ st .none (),
944+ hh .scalars (dtypes = st .just (x .dtype )),
945+ hh .arrays (dtype = dtypes , shape = shape2 ),
946+ ))
947+
948+ kw = data .draw (hh .specified_kwargs (("min" , min , None ), ("max" , max , None )))
949+
950+ out = xp .clip (x , ** kw )
931951 ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
932952 ph .assert_shape ("clip" , out_shape = out .shape , expected = x .shape )
933953 ph .assert_array_elements ("clip" , out = out , expected = x )
0 commit comments