@@ -764,30 +764,30 @@ end
764764dispatch_val (x) = x
765765dispatch_val (:: Val{D} ) where {D} = D
766766
767- function Base. _cat (dims, A :: TracedRArray{T,N } , Bs :: TracedRArray... ) where {T,N }
767+ function Base. _cat_t (dims, :: Type{T } , X :: TracedRArray... ) where {T}
768768 dims = dispatch_val (dims)
769769 @assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
770770
771771 # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
772- A = maybe_expand_dims (A, dims)
773- Bs = maybe_expand_dims .(Bs, (dims,))
772+ X = maybe_expand_dims .(X, (dims,))
774773
775774 catdims = Base. dims2cat (dims)
776- shape = Base. cat_size_shape (catdims, A, Bs... )
777- RT = Base. promote_eltype (A, Bs... )
778- Res = TracedRArray {RT,length(shape)} (
775+ shape = Base. cat_size_shape (catdims, X... )
776+ RT = Base. promote_eltype (T, X... )
777+
778+ return TracedRArray {RT,length(shape)} (
779779 (),
780780 MLIR. IR. result (
781+ # TODO maybe we should do some conversion?
781782 MLIR. Dialects. stablehlo. concatenate (
782- [A . mlir_data, [B . mlir_data for B in Bs] . .. ] ;
783+ get_mlir_data .(X) ;
783784 result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
784785 dimension= dims - 1 , # stablehlo expects this to be zero-indexed
785786 ),
786787 1 ,
787788 ),
788789 shape,
789790 )
790- return Res
791791end
792792
793793function maybe_expand_dims (x:: AbstractArray{T,N} , dims) where {T,N}
0 commit comments