@@ -744,6 +744,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
744744 axis = 0
745745 return torch .index_select (x , axis , indices , ** kwargs )
746746
747+
748+ def take_along_axis (x : array , indices : array , / , * , axis : int = - 1 ) -> array :
749+ return torch .take_along_dim (x , indices , dim = axis )
750+
751+
747752def sign (x : array , / ) -> array :
748753 # torch sign() does not support complex numbers and does not propagate
749754 # nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -767,14 +772,14 @@ def sign(x: array, /) -> array:
767772 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
768773 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
769774 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
770- 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'sort' , 'prod' , 'sum' ,
775+ 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , ' sort' , 'prod' , 'sum' ,
771776 'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze' ,
772777 'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
773778 'arange' , 'eye' , 'linspace' , 'full' , 'ones' , 'zeros' , 'empty' ,
774779 'tril' , 'triu' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
775780 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
776781 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
777782 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
778- 'take' , 'sign' ]
783+ 'take' , 'take_along_axis' , ' sign' ]
779784
780785_all_ignore = ['torch' , 'get_xp' ]
0 commit comments