@@ -761,32 +761,87 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
761761 return dest
762762end
763763
764- function Base . _cat (dims :: Val{D} , A :: TracedRArray{T,N} , Bs :: TracedRArray... ) where {T,N,D}
765- @assert D isa Integer " Support for non-integer dimensions is not implemented yet. "
764+ dispatch_val (x) = x
765+ dispatch_val ( :: Val{D} ) where {D} = D
766766
767- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
768- A = maybe_expand_dims (A, dims)
769- Bs = maybe_expand_dims .(Bs, (dims,))
767+ @inline function Base. _typed_vcat (
768+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
769+ ) where {T}
770+ return Base. _cat_t (Val (1 ), T, X... )
771+ end
772+ @inline function Base. _typed_hcat (
773+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
774+ ) where {T}
775+ return Base. _cat_t (Val (2 ), T, X... )
776+ end
777+
778+ # `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
779+ # generic implementation uses `typed_hcat` and `typed_vcat` which is alright
780+ @inline function Base. typed_hvcat (
781+ :: Type{T} , rows:: Tuple{Vararg{Int}} , as:: TracedRArray...
782+ ) where {T}
783+ return invoke (
784+ Base. typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
785+ )
786+ end
787+
788+ function Base. _typed_hvncat (
789+ T:: Type , dims:: NTuple{N,Int} , row_first:: Bool , as:: TracedRArray...
790+ ) where {N}
791+ As = if row_first
792+ perm = [2 , 1 , 3 : N... ]
793+ dims = [dims[2 ], dims[1 ], dims[3 : end ]. .. ]
794+ permutedims (reshape (collect (as), dims... ), perm)
795+ else
796+ reshape (collect (as), dims)
797+ end
798+
799+ for d in 1 : N
800+ Bs = Array {Any,N - d} (undef, size (As)[2 : end ]. .. )
801+
802+ for (i, col) in
803+ zip (eachindex (Bs), eachslice (As; dims= Tuple (2 : ndims (As)), drop= true ))
804+ # TODO row_first affects the flattening?
805+ Bs[i] = Base. _cat_t (d, T, col... )
806+ end
807+
808+ As = Bs
809+ end
810+
811+ return only (As)
812+ end
813+
814+ function Base. _cat_t (dims, :: Type{T} , X:: TracedRArray... ) where {T}
815+ dims = dispatch_val (dims)
816+ @assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
817+
818+ # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
819+ X = maybe_expand_dims .(X, (dims,))
770820
771821 catdims = Base. dims2cat (dims)
772- shape = Base. cat_size_shape (catdims, A, Bs... )
773- RT = Base. promote_eltype (A, Bs... )
774- Res = TracedRArray {RT,length(shape)} (
822+ shape = Base. cat_size_shape (catdims, X... )
823+ RT = Base. promote_eltype (T, X... )
824+
825+ # convert to the target eltype
826+ X = map (Base. Fix1 (promote_to, TracedRArray{RT,length (shape)}), X)
827+
828+ return TracedRArray {RT,length(shape)} (
775829 (),
776830 MLIR. IR. result (
831+ # TODO maybe we should do some conversion?
777832 MLIR. Dialects. stablehlo. concatenate (
778- [A . mlir_data, [B . mlir_data for B in Bs] . .. ] ;
833+ collect ( get_mlir_data .(X)) ;
779834 result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
780- dimension= D - 1 , # stablehlo expects this to be zero-indexed
835+ dimension= dims - 1 , # stablehlo expects this to be zero-indexed
781836 ),
782837 1 ,
783838 ),
784839 shape,
785840 )
786- return Res
787841end
788842
789- function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
790- D ≤ N && return x
791- return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
843+ function maybe_expand_dims (x:: AbstractArray{T,N} , dims) where {T,N}
844+ dims = dispatch_val (dims)
845+ dims ≤ N && return x
846+ return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , dims))
792847end
0 commit comments