@@ -60,3 +60,55 @@ def test_take(x, data):
6060 # sanity check
6161 with pytest .raises (StopIteration ):
6262 next (out_indices )
63+
64+
65+
66+ @pytest .mark .unvectorized
67+ @pytest .mark .min_version ("2024.12" )
68+ @given (
69+ x = hh .arrays (hh .all_dtypes , hh .shapes (min_dims = 1 , min_side = 1 )),
70+ data = st .data (),
71+ )
72+ def test_take_along_axis (x , data ):
73+ # TODO
74+ # 1. negative axis
75+ # 2. negative indices
76+ # 3. different dtypes for indices
77+ axis = data .draw (st .integers (0 , max (x .ndim - 1 , 0 )), label = "axis" )
78+ len_axis = data .draw (st .integers (0 , 2 * x .shape [axis ]), label = "len_axis" )
79+
80+ idx_shape = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :]
81+ indices = data .draw (
82+ hh .arrays (
83+ shape = idx_shape ,
84+ dtype = dh .default_int ,
85+ elements = {"min_value" : 0 , "max_value" : x .shape [axis ]- 1 }
86+ ),
87+ label = "indices"
88+ )
89+ note (f"{ indices = } { idx_shape = } " )
90+
91+ out = xp .take_along_axis (x , indices , axis = axis )
92+
93+ ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
94+ ph .assert_shape (
95+ "take_along_axis" ,
96+ out_shape = out .shape ,
97+ expected = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :],
98+ kw = dict (
99+ x = x ,
100+ indices = indices ,
101+ axis = axis ,
102+ ),
103+ )
104+
105+ # value test: notation is from `np.take_along_axis` docstring
106+ Ni , Nk = x .shape [:axis ], x .shape [axis + 1 :]
107+ for ii in sh .ndindex (Ni ):
108+ for kk in sh .ndindex (Nk ):
109+ a_1d = x [ii + (slice (None ),) + kk ]
110+ i_1d = indices [ii + (slice (None ),) + kk ]
111+ o_1d = out [ii + (slice (None ),) + kk ]
112+ for j in range (len_axis ):
113+ assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
114+
0 commit comments