@@ -60,3 +60,63 @@ def test_take(x, data):
6060 # sanity check
6161 with pytest .raises (StopIteration ):
6262 next (out_indices )
63+
64+
65+
66+ @pytest .mark .unvectorized
67+ @pytest .mark .min_version ("2024.12" )
68+ @given (
69+ x = hh .arrays (hh .all_dtypes , hh .shapes (min_dims = 1 , min_side = 1 )),
70+ data = st .data (),
71+ )
72+ def test_take_along_axis (x , data ):
73+ # TODO
74+ # 2. negative indices
75+ # 3. different dtypes for indices
76+ # 4. "broadcast-compatible" indices
77+ axis = data .draw (
78+ st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
79+ label = "axis"
80+ )
81+ if axis is None :
82+ axis_kw = {}
83+ n_axis = x .ndim - 1
84+ else :
85+ axis_kw = {"axis" : axis }
86+ n_axis = axis + x .ndim if axis < 0 else axis
87+
88+ new_len = data .draw (st .integers (0 , 2 * x .shape [n_axis ]), label = "new_len" )
89+ idx_shape = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :]
90+ indices = data .draw (
91+ hh .arrays (
92+ shape = idx_shape ,
93+ dtype = dh .default_int ,
94+ elements = {"min_value" : 0 , "max_value" : x .shape [n_axis ]- 1 }
95+ ),
96+ label = "indices"
97+ )
98+ note (f"{ indices = } { idx_shape = } " )
99+
100+ out = xp .take_along_axis (x , indices , ** axis_kw )
101+
102+ ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
103+ ph .assert_shape (
104+ "take_along_axis" ,
105+ out_shape = out .shape ,
106+ expected = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :],
107+ kw = dict (
108+ x = x ,
109+ indices = indices ,
110+ axis = axis ,
111+ ),
112+ )
113+
114+ # value test: notation is from `np.take_along_axis` docstring
115+ Ni , Nk = x .shape [:n_axis ], x .shape [n_axis + 1 :]
116+ for ii in sh .ndindex (Ni ):
117+ for kk in sh .ndindex (Nk ):
118+ a_1d = x [ii + (slice (None ),) + kk ]
119+ i_1d = indices [ii + (slice (None ),) + kk ]
120+ o_1d = out [ii + (slice (None ),) + kk ]
121+ for j in range (new_len ):
122+ assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
0 commit comments