@@ -46,26 +46,27 @@ def assert_array_ndindex(
4646 assert out [out_idx ] == x [x_idx ], msg
4747
4848
49- @st .composite
50- def concat_shapes (draw , shape , axis ):
51- shape = list (shape )
52- shape [axis ] = draw (st .integers (1 , MAX_SIDE ))
53- return tuple (shape )
54-
55-
5649@given (
5750 dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
58- kw = hh .kwargs ( axis = st . none () | st . integers ( - MAX_DIMS , MAX_DIMS - 1 ) ),
51+ base_shape = hh .shapes ( ),
5952 data = st .data (),
6053)
61- def test_concat (dtypes , kw , data ):
54+ def test_concat (dtypes , base_shape , data ):
55+ axis_strat = st .none ()
56+ ndim = len (base_shape )
57+ if ndim > 0 :
58+ axis_strat |= st .integers (- ndim , ndim - 1 )
59+ kw = data .draw (
60+ axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
61+ )
6262 axis = kw .get ("axis" , 0 )
6363 if axis is None :
64+ _axis = None
6465 shape_strat = hh .shapes ()
6566 else :
66- _axis = axis if axis >= 0 else abs ( axis ) - 1
67- shape_strat = shared_shapes ( min_dims = _axis + 1 ). flatmap (
68- lambda s : concat_shapes ( s , axis )
67+ _axis = axis if axis >= 0 else len ( base_shape ) + axis
68+ shape_strat = st . integers ( 0 , MAX_SIDE ). map (
69+ lambda i : base_shape [: _axis ] + ( i ,) + base_shape [ _axis + 1 :]
6970 )
7071 arrays = []
7172 for i , dtype in enumerate (dtypes , 1 ):
@@ -77,18 +78,17 @@ def test_concat(dtypes, kw, data):
7778 ph .assert_dtype ("concat" , dtypes , out .dtype )
7879
7980 shapes = tuple (x .shape for x in arrays )
80- axis = kw .get ("axis" , 0 )
81- if axis is None :
81+ if _axis is None :
8282 size = sum (math .prod (s ) for s in shapes )
8383 shape = (size ,)
8484 else :
8585 shape = list (shapes [0 ])
8686 for other_shape in shapes [1 :]:
87- shape [axis ] += other_shape [axis ]
87+ shape [_axis ] += other_shape [_axis ]
8888 shape = tuple (shape )
8989 ph .assert_result_shape ("concat" , shapes , out .shape , shape , ** kw )
9090
91- if axis is None :
91+ if _axis is None :
9292 out_indices = (i for i in range (out .size ))
9393 for x_num , x in enumerate (arrays , 1 ):
9494 for x_idx in sh .ndindex (x .shape ):
@@ -291,7 +291,7 @@ def test_roll(x, data):
291291 else :
292292 axis_strat = st .none ()
293293 if x .ndim != 0 :
294- axis_strat = axis_strat | st .integers (- x .ndim , x .ndim - 1 )
294+ axis_strat |= st .integers (- x .ndim , x .ndim - 1 )
295295 kw_strat = hh .kwargs (axis = axis_strat )
296296 kw = data .draw (kw_strat , label = "kw" )
297297
0 commit comments