@@ -88,6 +88,44 @@ def test_argmin(x, data):
8888 ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
8989
9090
91+ @pytest .mark .min_version ("2024.12" )
92+ @given (
93+ x = hh .arrays (
94+ dtype = hh .real_dtypes ,
95+ shape = hh .shapes (min_dims = 1 , min_side = 1 ),
96+ elements = {"allow_nan" : False },
97+ ),
98+ data = st .data (),
99+ )
100+ def test_count_nonzero (x , data ):
101+ kw = data .draw (
102+ hh .kwargs (
103+ axis = st .none () | st .integers (- x .ndim , max (x .ndim - 1 , 0 )),
104+ keepdims = st .booleans (),
105+ ),
106+ label = "kw" ,
107+ )
108+ keepdims = kw .get ("keepdims" , False )
109+
110+ out = xp .count_nonzero (x , ** kw )
111+
112+ ph .assert_default_index ("count_nonzero" , out .dtype )
113+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
114+ ph .assert_keepdimable_shape (
115+ "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
116+ )
117+ scalar_type = dh .get_scalar_type (x .dtype )
118+
119+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
120+ count = int (out [out_idx ])
121+ elements = []
122+ for idx in indices :
123+ s = scalar_type (x [idx ])
124+ elements .append (s )
125+ expected = sum (el != 0 for el in elements )
126+ ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
127+
128+
91129@given (hh .arrays (dtype = hh .all_dtypes , shape = ()))
92130def test_nonzero_zerodim_error (x ):
93131 with pytest .raises (Exception ):
0 commit comments