@@ -48,31 +48,30 @@ def assert_array_ndindex(
4848
4949@given (
5050 dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
51- _axis = st . none () | st . integers ( 0 , MAX_DIMS - 1 ),
51+ base_shape = hh . shapes ( ),
5252 data = st .data (),
5353)
54- def test_concat (dtypes , _axis , data ):
55- if _axis is None :
54+ def test_concat (dtypes , base_shape , data ):
55+ axis_strat = st .none ()
56+ ndim = len (base_shape )
57+ if ndim > 0 :
58+ axis_strat |= st .integers (- ndim , ndim - 1 )
59+ kw = data .draw (
60+ axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
61+ )
62+ axis = kw .get ("axis" , 0 )
63+ if axis is None :
64+ _axis = None
5665 shape_strat = hh .shapes ()
57- axis_strat = st .none ()
5866 else :
59- base_shape = data .draw (
60- hh .shapes (min_dims = _axis + 1 ).map (
61- lambda t : t [:_axis ] + (None ,) + t [_axis + 1 :]
62- ),
63- label = "base shape" ,
64- )
67+ _axis = axis if axis >= 0 else len (base_shape ) + axis
6568 shape_strat = st .integers (0 , MAX_SIDE ).map (
6669 lambda i : base_shape [:_axis ] + (i ,) + base_shape [_axis + 1 :]
6770 )
68- axis_strat = st .sampled_from ([_axis , _axis - len (base_shape )])
6971 arrays = []
7072 for i , dtype in enumerate (dtypes , 1 ):
7173 x = data .draw (xps .arrays (dtype = dtype , shape = shape_strat ), label = f"x{ i } " )
7274 arrays .append (x )
73- kw = data .draw (
74- axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
75- )
7675
7776 out = xp .concat (arrays , ** kw )
7877
@@ -292,7 +291,7 @@ def test_roll(x, data):
292291 else :
293292 axis_strat = st .none ()
294293 if x .ndim != 0 :
295- axis_strat = axis_strat | st .integers (- x .ndim , x .ndim - 1 )
294+ axis_strat |= st .integers (- x .ndim , x .ndim - 1 )
296295 kw_strat = hh .kwargs (axis = axis_strat )
297296 kw = data .draw (kw_strat , label = "kw" )
298297
0 commit comments