@@ -34,7 +34,7 @@ defmodule EXLA.Lib do
3434 def argmax ( builder , op , type , opts \\ [ ] )
3535
3636 def argmax ( % Function { } = builder , % Value { } = op , type , opts ) do
37- argmin_or_max ( builder , op , false , type , opts )
37+ argmin_or_max ( builder , op , :max , type , opts )
3838 end
3939
4040 @ doc """
@@ -49,37 +49,43 @@ defmodule EXLA.Lib do
4949 def argmin ( builder , op , type , opts \\ [ ] )
5050
5151 def argmin ( % Function { } = builder , % Value { } = op , type , opts ) do
52- argmin_or_max ( builder , op , true , type , opts )
52+ argmin_or_max ( builder , op , :min , type , opts )
5353 end
5454
55- defp argmin_or_max ( builder , % Value { } = op , is_min? , type , opts ) do
55+ defp argmin_or_max ( builder , % Value { } = op , variant , type , opts ) do
5656 tie_break = opts [ :tie_break ] || :low
5757 keep_axis = opts [ :keep_axis ] || false
58+ axis = opts [ :axis ]
5859
5960 op_typespec = Value . get_typespec ( op )
6061
62+ { op , op_typespec } =
63+ if axis == nil and Nx . rank ( op_typespec . shape ) != 1 do
64+ # When no axis is given, we flatten the tensor and reduce over
65+ # the first axis
66+ typespec = Typespec . to_shape ( op_typespec , { Nx . size ( op_typespec . shape ) } )
67+ { Value . reshape ( op , typespec ) , typespec }
68+ else
69+ { op , op_typespec }
70+ end
71+
72+ axis = axis || 0
73+
6174 init_value =
62- if is_min? ,
63- do: max_number ( builder , op_typespec . type ) ,
64- else: min_number ( builder , op_typespec . type )
75+ case variant do
76+ :min -> max_number ( builder , op_typespec . type )
77+ :max -> min_number ( builder , op_typespec . type )
78+ end
6579
66- axis = opts [ :axis ]
6780 index_init_value = Value . constant ( builder , [ 0 ] , Typespec . tensor ( type , { } ) )
6881 iota = iota ( builder , axis , Typespec . to_type ( op_typespec , type ) )
69- reduction = create_min_max_computation ( builder , op_typespec . type , type , is_min? , tie_break )
82+ reduction = create_min_max_computation ( builder , op_typespec . type , type , variant , tie_break )
7083
71- dims =
72- if axis do
73- [ axis ]
74- else
75- Nx . axes ( op_typespec . shape )
76- end
77-
78- shape = remove_axes ( op_typespec . shape , dims )
84+ shape = Tuple . delete_at ( op_typespec . shape , axis )
7985 typespecs = [ Typespec . tensor ( op_typespec . type , shape ) , Typespec . tensor ( type , shape ) ]
8086
8187 [ _ , result ] =
82- Value . reduce ( reduction , [ init_value , index_init_value ] , [ op , iota ] , dims , typespecs )
88+ Value . reduce ( reduction , [ init_value , index_init_value ] , [ op , iota ] , [ axis ] , typespecs )
8389
8490 if keep_axis do
8591 Value . reshape ( result , Typespec . tensor ( type , put_elem ( op_typespec . shape , axis , 1 ) ) )
@@ -88,13 +94,7 @@ defmodule EXLA.Lib do
8894 end
8995 end
9096
91- defp remove_axes ( shape , axes ) do
92- axes
93- |> Enum . reverse ( )
94- |> Enum . reduce ( shape , & Tuple . delete_at ( & 2 , & 1 ) )
95- end
96-
97- defp create_min_max_computation ( % Function { } = function , type , index_type , is_min? , tie_break ) do
97+ defp create_min_max_computation ( % Function { } = function , type , index_type , variant , tie_break ) do
9898 arg_typespecs = [
9999 Typespec . tensor ( type , { } ) ,
100100 Typespec . tensor ( index_type , { } ) ,
@@ -109,27 +109,42 @@ defmodule EXLA.Lib do
109109 value_typespec = Typespec . tensor ( type , { } )
110110 idx_typespec = Typespec . tensor ( index_type , { } )
111111
112- cmp =
113- if is_min? ,
114- do: Value . less_equal ( lhs_value , rhs_value , pred_typespec ) ,
115- else: Value . greater_equal ( lhs_value , rhs_value , pred_typespec )
112+ comparator =
113+ case variant do
114+ :min -> & Value . less / 3
115+ :max -> & Value . greater / 3
116+ end
117+
118+ # Pick lhs if strictly before or if it is NaN
119+ pick_lhs_value =
120+ Value . bitwise_or (
121+ comparator . ( lhs_value , rhs_value , pred_typespec ) ,
122+ Value . is_nan ( lhs_value , pred_typespec ) ,
123+ pred_typespec
124+ )
116125
117- max = Value . select ( cmp , lhs_value , rhs_value , value_typespec )
118- arg_max = Value . select ( cmp , lhs_index , rhs_index , idx_typespec )
126+ max = Value . select ( pick_lhs_value , lhs_value , rhs_value , value_typespec )
119127
120- arg_max =
128+ idx_comparator =
121129 case tie_break do
122- :low ->
123- eq? = Value . equal ( lhs_value , rhs_value , pred_typespec )
124- id = Value . min ( lhs_index , rhs_index , idx_typespec )
125- Value . select ( eq? , id , arg_max , idx_typespec )
126-
127- :high ->
128- eq? = Value . equal ( lhs_value , rhs_value , pred_typespec )
129- id = Value . max ( lhs_index , rhs_index , idx_typespec )
130- Value . select ( eq? , id , arg_max , idx_typespec )
130+ :low -> & Value . less / 3
131+ :high -> & Value . greater / 3
131132 end
132133
134+ # If lhs and rhs are equal (and not NaN), then pick index based on tie_break
135+ pick_lhs_idx =
136+ Value . bitwise_or (
137+ pick_lhs_value ,
138+ Value . bitwise_and (
139+ Value . equal ( lhs_value , rhs_value , pred_typespec ) ,
140+ idx_comparator . ( lhs_index , rhs_index , pred_typespec ) ,
141+ pred_typespec
142+ ) ,
143+ pred_typespec
144+ )
145+
146+ arg_max = Value . select ( pick_lhs_idx , lhs_index , rhs_index , idx_typespec )
147+
133148 Value . return ( function , [ max , arg_max ] )
134149 Function . pop_region ( function )
135150 region
0 commit comments