@@ -3251,20 +3251,9 @@ defmodule Nx do
32513251 """
32523252 @ doc type: :shape , from_backend: false
32533253 def new_axis ( tensor , axis , name \\ nil ) when is_integer ( axis ) do
3254- apply_vectorized ( tensor , fn tensor , offset ->
3255- % { shape: shape , names: names } = tensor = to_tensor ( tensor )
3256- rank = tuple_size ( shape )
3257- norm = if axis < 0 , do: axis + rank + 1 , else: axis + offset
3258-
3259- if norm not in offset .. tuple_size ( shape ) do
3260- raise ArgumentError ,
3261- "new axis position for shape #{ inspect ( shape ) } must be " <>
3262- "a number between #{ - rank - 1 + offset } and #{ rank - offset } , got: #{ axis } "
3263- end
3264-
3265- new_shape = Tuple . insert_at ( shape , norm , 1 )
3266- new_names = List . insert_at ( names , norm , name )
3267- impl! ( tensor ) . reshape ( % { tensor | shape: new_shape , names: new_names } , tensor )
3254+ apply_vectorized ( tensor , fn % { shape: shape , names: names } = tensor , offset ->
3255+ { shape , names , _axis } = Nx.Shape . new_axis ( shape , names , axis , name , 1 , offset )
3256+ impl! ( tensor ) . reshape ( % { tensor | shape: shape , names: names } , tensor )
32683257 end )
32693258 end
32703259
@@ -14668,28 +14657,35 @@ defmodule Nx do
1466814657 t
1466914658
1467014659 [ _ | _ ] = tensors ->
14671- [ % T { vectorized_axes: vectorized_axes } | _ ] =
14672- tensors = broadcast_vectors ( tensors , align_ranks: true )
14660+ concatenate_or_stack (
14661+ tensors ,
14662+ fn shapes , names , offset -> Nx.Shape . concatenate ( shapes , names , axis , offset ) end ,
14663+ fn out , tensors , axis -> list_impl! ( tensors ) . concatenate ( out , tensors , axis ) end
14664+ )
14665+ end
14666+ end
1467314667
14674- offset = length ( vectorized_axes )
14675- tensors = if vectorized_axes != [ ] , do: Enum . map ( tensors , & devectorize / 1 ) , else: tensors
14668+ defp concatenate_or_stack ( tensors , shape_and_name , callback ) do
14669+ [ % T { vectorized_axes: vectorized_axes } | _ ] =
14670+ tensors = broadcast_vectors ( tensors , align_ranks: true )
1467614671
14677- { types , [ s1 | _ ] = shapes , [ n1 | _ ] = names } =
14678- Enum . reduce ( tensors , { [ ] , [ ] , [ ] } , fn
14679- % T { type: t , shape: s , names: n } , { types , shapes , names } ->
14680- { [ t | types ] , [ s | shapes ] , [ n | names ] }
14681- end )
14672+ offset = length ( vectorized_axes )
14673+ tensors = if vectorized_axes != [ ] , do: Enum . map ( tensors , & devectorize / 1 ) , else: tensors
14674+
14675+ { types , shapes , names } =
14676+ Enum . reduce ( tensors , { [ ] , [ ] , [ ] } , fn
14677+ % T { type: t , shape: s , names: n } , { types , shapes , names } ->
14678+ { [ t | types ] , [ s | shapes ] , [ n | names ] }
14679+ end )
1468214680
14683- axis = Nx.Shape . normalize_axis ( s1 , axis , n1 , offset )
14684- output_type = Enum . reduce ( types , & Nx.Type . merge / 2 )
14681+ output_type = Enum . reduce ( types , & Nx.Type . merge / 2 )
1468514682
14686- { output_shape , output_names } =
14687- Nx.Shape . concatenate ( Enum . reverse ( shapes ) , Enum . reverse ( names ) , axis )
14683+ { output_shape , output_names , axis } =
14684+ shape_and_name . ( Enum . reverse ( shapes ) , Enum . reverse ( names ) , offset )
1468814685
14689- out = % { hd ( tensors ) | type: output_type , shape: output_shape , names: output_names }
14690- result = list_impl! ( tensors ) . concatenate ( out , tensors , axis )
14691- vectorize ( result , vectorized_axes )
14692- end
14686+ out = % { hd ( tensors ) | type: output_type , shape: output_shape , names: output_names }
14687+ result = callback . ( out , tensors , axis )
14688+ vectorize ( result , vectorized_axes )
1469314689 end
1469414690
1469514691 defp flatten_list_or_container ( list ) when is_list ( list ) do
@@ -14807,16 +14803,26 @@ defmodule Nx do
1480714803 >
1480814804
1480914805 """
14810- @ doc type: :ndim , from_backend: false
14806+ @ doc type: :ndim
1481114807 def stack ( tensors , opts \\ [ ] ) do
1481214808 opts = keyword! ( opts , axis: 0 , name: nil )
1481314809 axis = opts [ :axis ]
1481414810 name = opts [ :name ]
1481514811
14816- tensors
14817- |> flatten_list_or_container ( )
14818- |> Enum . map ( & Nx . new_axis ( & 1 , axis , name ) )
14819- |> Nx . concatenate ( axis: axis )
14812+ case flatten_list_or_container ( tensors ) do
14813+ [ ] ->
14814+ raise ArgumentError , "no tensors were given to stack"
14815+
14816+ [ t ] ->
14817+ Nx . new_axis ( t , axis , name )
14818+
14819+ [ _ | _ ] = tensors ->
14820+ concatenate_or_stack (
14821+ tensors ,
14822+ fn shapes , names , offset -> Nx.Shape . stack ( shapes , names , axis , name , offset ) end ,
14823+ fn out , tensors , axis -> list_impl! ( tensors ) . stack ( out , tensors , axis ) end
14824+ )
14825+ end
1482014826 end
1482114827
1482214828 @ doc """
0 commit comments