@@ -88,6 +88,52 @@ def test_argmin(x, data):
8888 ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
8989
9090
91+ # XXX: dtype= stanza below is to work around unsigned int dtypes in torch
92+ # (count_nonzero_cpu not implemented for uint32 etc)
93+ # XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
94+ # the problem is tha for ints >iinfo(int32) it runs into essentially this:
95+ # >>> jnp.asarray[2147483648], dtype=jnp.int64)
96+ # .... https://github.com/jax-ml/jax/pull/6047 ...
97+ # Explicitly limiting the range in elements(...) runs into problems with
98+ # hypothesis where floating-point numbers are not exactly representable.
99+ @pytest .mark .min_version ("2024.12" )
100+ @given (
101+ x = hh .arrays (
102+ dtype = st .sampled_from (dh .int_dtypes + dh .real_float_dtypes + dh .complex_dtypes + (xp .bool ,)),
103+ shape = hh .shapes (min_dims = 1 , min_side = 1 ),
104+ elements = {"allow_nan" : False },
105+ ),
106+ data = st .data (),
107+ )
108+ def test_count_nonzero (x , data ):
109+ kw = data .draw (
110+ hh .kwargs (
111+ axis = st .none () | st .integers (- x .ndim , max (x .ndim - 1 , 0 )),
112+ keepdims = st .booleans (),
113+ ),
114+ label = "kw" ,
115+ )
116+ keepdims = kw .get ("keepdims" , False )
117+
118+ out = xp .count_nonzero (x , ** kw )
119+
120+ ph .assert_default_index ("count_nonzero" , out .dtype )
121+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
122+ ph .assert_keepdimable_shape (
123+ "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
124+ )
125+ scalar_type = dh .get_scalar_type (x .dtype )
126+
127+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
128+ count = int (out [out_idx ])
129+ elements = []
130+ for idx in indices :
131+ s = scalar_type (x [idx ])
132+ elements .append (s )
133+ expected = sum (el != 0 for el in elements )
134+ ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
135+
136+
91137@given (hh .arrays (dtype = hh .all_dtypes , shape = ()))
92138def test_nonzero_zerodim_error (x ):
93139 with pytest .raises (Exception ):
0 commit comments