@@ -46,49 +46,50 @@ def assert_array_ndindex(
4646 assert out [out_idx ] == x [x_idx ], msg
4747
4848
49- @st .composite
50- def concat_shapes (draw , shape , axis ):
51- shape = list (shape )
52- shape [axis ] = draw (st .integers (1 , MAX_SIDE ))
53- return tuple (shape )
54-
55-
5649@given (
5750 dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
58- kw = hh . kwargs ( axis = st .none () | st .integers (- MAX_DIMS , MAX_DIMS - 1 ) ),
51+ _axis = st .none () | st .integers (0 , MAX_DIMS - 1 ),
5952 data = st .data (),
6053)
61- def test_concat (dtypes , kw , data ):
62- axis = kw .get ("axis" , 0 )
63- if axis is None :
54+ def test_concat (dtypes , _axis , data ):
55+ if _axis is None :
6456 shape_strat = hh .shapes ()
57+ axis_strat = st .none ()
6558 else :
66- any_side_axis = axis if axis >= 0 else abs (axis ) - 1
67- shape_strat = shared_shapes (min_dims = any_side_axis + 1 ).flatmap (
68- lambda s : concat_shapes (s , any_side_axis )
59+ base_shape = data .draw (
60+ hh .shapes (min_dims = _axis + 1 ).map (
61+ lambda t : t [:_axis ] + (None ,) + t [_axis + 1 :]
62+ ),
63+ label = "base shape" ,
64+ )
65+ shape_strat = st .integers (0 , MAX_SIDE ).map (
66+ lambda i : base_shape [:_axis ] + (i ,) + base_shape [_axis + 1 :]
6967 )
68+ axis_strat = st .sampled_from ([_axis , _axis - len (base_shape )])
7069 arrays = []
7170 for i , dtype in enumerate (dtypes , 1 ):
7271 x = data .draw (xps .arrays (dtype = dtype , shape = shape_strat ), label = f"x{ i } " )
7372 arrays .append (x )
73+ kw = data .draw (
74+ axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
75+ )
7476
7577 out = xp .concat (arrays , ** kw )
7678
7779 ph .assert_dtype ("concat" , dtypes , out .dtype )
7880
7981 shapes = tuple (x .shape for x in arrays )
80- axis = kw .get ("axis" , 0 )
81- if axis is None :
82+ if _axis is None :
8283 size = sum (math .prod (s ) for s in shapes )
8384 shape = (size ,)
8485 else :
8586 shape = list (shapes [0 ])
8687 for other_shape in shapes [1 :]:
87- shape [axis ] += other_shape [axis ]
88+ shape [_axis ] += other_shape [_axis ]
8889 shape = tuple (shape )
8990 ph .assert_result_shape ("concat" , shapes , out .shape , shape , ** kw )
9091
91- if axis is None :
92+ if _axis is None :
9293 out_indices = (i for i in range (out .size ))
9394 for x_num , x in enumerate (arrays , 1 ):
9495 for x_idx in sh .ndindex (x .shape ):
@@ -102,8 +103,6 @@ def test_concat(dtypes, kw, data):
102103 ** kw ,
103104 )
104105 else :
105- ndim = len (shapes [0 ])
106- _axis = axis if axis >= 0 else ndim - 1
107106 out_indices = sh .ndindex (out .shape )
108107 for idx in sh .axis_ndindex (shapes [0 ], _axis ):
109108 f_idx = ", " .join (str (i ) if isinstance (i , int ) else ":" for i in idx )
0 commit comments