@@ -464,36 +464,6 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
464464 return A
465465end
466466
467- function Base. _cat (dims:: Val{D} , A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N,D}
468- @assert D isa Integer " Support for non-integer dimensions is not implemented yet."
469-
470- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
471- A = maybe_expand_dims (A, dims)
472- Bs = maybe_expand_dims .(Bs, (dims,))
473-
474- catdims = Base. dims2cat (dims)
475- shape = Base. cat_size_shape (catdims, A, Bs... )
476- RT = Base. promote_eltype (A, Bs... )
477- Res = TracedRArray {RT,length(shape)} (
478- (),
479- MLIR. IR. result (
480- MLIR. Dialects. stablehlo. concatenate (
481- [A. mlir_data, [B. mlir_data for B in Bs]. .. ];
482- result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
483- dimension= D - 1 , # stablehlo expects this to be zero-indexed
484- ),
485- 1 ,
486- ),
487- shape,
488- )
489- return Res
490- end
491-
492- function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
493- D ≤ N && return x
494- return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
495- end
496-
497467struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
498468
499469AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
@@ -648,3 +618,88 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
648618 dest. mlir_data = res. mlir_data
649619 return dest
650620end
621+
622+ dispatch_val (x) = x
623+ dispatch_val (:: Val{D} ) where {D} = D
624+
625+ @inline function Base. _typed_vcat (
626+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
627+ ) where {T}
628+ return Base. _cat_t (Val (1 ), T, X... )
629+ end
630+ @inline function Base. _typed_hcat (
631+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
632+ ) where {T}
633+ return Base. _cat_t (Val (2 ), T, X... )
634+ end
635+
636+ # `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
637+ # generic implementation uses `typed_hcat` and `typed_vcat` which is alright
638+ @inline function Base. typed_hvcat (
639+ :: Type{T} , rows:: Tuple{Vararg{Int}} , as:: TracedRArray...
640+ ) where {T}
641+ return invoke (
642+ Base. typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
643+ )
644+ end
645+
646+ function Base. _typed_hvncat (
647+ T:: Type , dims:: NTuple{N,Int} , row_first:: Bool , as:: TracedRArray...
648+ ) where {N}
649+ As = if row_first
650+ perm = [2 , 1 , 3 : N... ]
651+ dims = [dims[2 ], dims[1 ], dims[3 : end ]. .. ]
652+ permutedims (reshape (collect (as), dims... ), perm)
653+ else
654+ reshape (collect (as), dims)
655+ end
656+
657+ for d in 1 : N
658+ Bs = Array {Any,N - d} (undef, size (As)[2 : end ]. .. )
659+
660+ for (i, col) in
661+ zip (eachindex (Bs), eachslice (As; dims= Tuple (2 : ndims (As)), drop= true ))
662+ # TODO row_first affects the flattening?
663+ Bs[i] = Base. _cat_t (d, T, col... )
664+ end
665+
666+ As = Bs
667+ end
668+
669+ return only (As)
670+ end
671+
672+ function Base. _cat_t (dims, :: Type{T} , X:: TracedRArray... ) where {T}
673+ dims = dispatch_val (dims)
674+ @assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
675+
676+ # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
677+ X = maybe_expand_dims .(X, (dims,))
678+
679+ catdims = Base. dims2cat (dims)
680+ shape = Base. cat_size_shape (catdims, X... )
681+ RT = Base. promote_eltype (T, X... )
682+
683+ # convert to the target eltype
684+ X = map (Base. Fix1 (promote_to, TracedRArray{RT,length (shape)}), X)
685+
686+ return TracedRArray {RT,length(shape)} (
687+ (),
688+ MLIR. IR. result (
689+ # TODO maybe we should do some conversion?
690+ MLIR. Dialects. stablehlo. concatenate (
691+ collect (get_mlir_data .(X));
692+ result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
693+ dimension= dims - 1 , # stablehlo expects this to be zero-indexed
694+ ),
695+ 1 ,
696+ ),
697+ shape,
698+ )
699+ end
700+
701+ function maybe_expand_dims (x:: AbstractArray{T,N} , dims) where {T,N}
702+ dims = dispatch_val (dims)
703+ dims ≤ N && return x
704+ return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , dims))
705+ end
0 commit comments