@@ -150,9 +150,6 @@ defmodule Nx.Defn.Grad do
150150 defp reduce_args ( :put_slice , % { data: % { args: [ arg , _ , update | _ ] } } , acc , fun ) ,
151151 do: fun . ( arg , fun . ( update , acc ) )
152152
153- defp reduce_args ( :take_along_axis , % { data: % { args: [ arg | _ ] } } , acc , fun ) ,
154- do: fun . ( arg , acc )
155-
156153 defp reduce_args ( :gather , % { data: % { args: [ arg | _ ] } } , acc , fun ) ,
157154 do: fun . ( arg , acc )
158155
@@ -663,44 +660,6 @@ defmodule Nx.Defn.Grad do
663660 [ { t , g } ]
664661 end
665662
666- defp grad ( :take_along_axis , [ t , i , axis ] , _ans , g ) do
667- num_elements = i |> Nx . shape ( ) |> Tuple . product ( )
668-
669- # Convert `i`, the take_along_axis indices, to a list of
670- # fully qualified (i.e. [0, 2, 1] for a {_, _, _}-shaped tensor)
671- # indices
672-
673- indices =
674- 0 .. ( Nx . rank ( g ) - 1 ) // 1
675- |> Enum . map ( fn
676- # For the axis of interest, we'll use the actual take_along_axis indices
677- ^ axis ->
678- Nx . reshape ( i , { num_elements , 1 } )
679-
680- axis ->
681- i
682- |> Nx . shape ( )
683- |> Nx . iota ( axis: axis )
684- |> Nx . reshape ( { num_elements , 1 } )
685- end )
686- |> Nx . concatenate ( axis: 1 )
687-
688- # Since g is produced through the given indices,
689- # we can reshape g to be a {num_elements} shaped tensor
690- # which will directly correspond to each of the reshaped
691- # indices above
692- updates = Nx . reshape ( g , { num_elements } )
693-
694- # The intuition for this grad is that for each index taken, we'll
695- # add the corresponding result grad to the original
696- g =
697- t
698- |> Expr . broadcast ( 0 , Nx . shape ( t ) , Nx . axes ( t ) )
699- |> Nx . indexed_add ( indices , updates )
700-
701- [ { t , g } ]
702- end
703-
704663 defp grad ( :gather , [ t , i , opts ] , _ans , g ) do
705664 i_axes = opts [ :axes ]
706665 i_shape = i . shape
@@ -714,6 +673,7 @@ defmodule Nx.Defn.Grad do
714673
715674 g =
716675 0
676+ |> Nx . as_type ( t . type )
717677 |> Nx . broadcast ( t_shape )
718678 |> Nx . indexed_add ( indices , updates , opts )
719679
0 commit comments