@@ -349,7 +349,8 @@ def test_repeat(x, kw, data):
349349 start = end
350350
351351@st .composite
352- def reshape_shapes (draw , shape ):
352+ def reshape_shapes (draw , shapes ):
353+ shape = draw (shapes )
353354 size = 1 if len (shape ) == 0 else math .prod (shape )
354355 rshape = draw (st .lists (st .integers (0 )).filter (lambda s : math .prod (s ) == size ))
355356 assume (all (side <= MAX_SIDE for side in rshape ))
@@ -359,15 +360,14 @@ def reshape_shapes(draw, shape):
359360 return tuple (rshape )
360361
361362
363+ reshape_shape = st .shared (hh .shapes (max_side = MAX_SIDE ), key = "reshape_shape" )
364+
362365@pytest .mark .unvectorized
363- @pytest .mark .skip ("flaky" ) # TODO: fix!
364366@given (
365- x = hh .arrays (dtype = hh .all_dtypes , shape = hh . shapes ( max_side = MAX_SIDE ) ),
366- data = st . data ( ),
367+ x = hh .arrays (dtype = hh .all_dtypes , shape = reshape_shape ),
368+ shape = reshape_shapes ( reshape_shape ),
367369)
368- def test_reshape (x , data ):
369- shape = data .draw (reshape_shapes (x .shape ))
370-
370+ def test_reshape (x , shape ):
371371 out = xp .reshape (x , shape )
372372
373373 ph .assert_dtype ("reshape" , in_dtype = x .dtype , out_dtype = out .dtype )
0 commit comments