@@ -455,18 +455,41 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
455455 return A
456456end
457457
458+ function Base. _cat (dims:: Val{D} , A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N,D}
459+ @assert D isa Integer " Support for non-integer dimensions is not implemented yet."
460+
461+ # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
462+ A = maybe_expand_dims (A, dims)
463+ Bs = maybe_expand_dims .(Bs, (dims,))
464+
465+ catdims = Base. dims2cat (dims)
466+ shape = Base. cat_size_shape (catdims, A, Bs... )
467+ RT = Base. promote_eltype (A, Bs... )
468+ Res = TracedRArray {RT,length(shape)} (
469+ (),
470+ MLIR. IR. result (
471+ MLIR. Dialects. stablehlo. concatenate (
472+ [A. mlir_data, [B. mlir_data for B in Bs]. .. ];
473+ result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
474+ dimension= D - 1 , # stablehlo expects this to be zero-indexed
475+ ),
476+ 1 ,
477+ ),
478+ shape,
479+ )
480+ return Res
481+ end
482+
483+ function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
484+ D ≤ N && return x
485+ return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
486+ end
487+
458488struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
459489
460490AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
461491AbstractReactantArrayStyle {M} (:: Val{N} ) where {N,M} = AbstractReactantArrayStyle {N} ()
462492
463- # function Broadcast.materialize(bc::Broadcasted)
464- # @show bc
465- # inst = instantiate(bc)
466- # @show inst
467- # copy(inst)
468- # end
469-
470493function BroadcastStyle (:: Type{<:AnyTracedRArray{T,N}} ) where {T,N}
471494 return AbstractReactantArrayStyle {N} ()
472495end
@@ -628,33 +651,3 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
628651 dest. mlir_data = res. mlir_data
629652 return dest
630653end
631-
632- function Base. _cat (dims:: Val{D} , A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N,D}
633- @assert D isa Integer " Support for non-integer dimensions is not implemented yet."
634-
635- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
636- A = maybe_expand_dims (A, dims)
637- Bs = maybe_expand_dims .(Bs, (dims,))
638-
639- catdims = Base. dims2cat (dims)
640- shape = Base. cat_size_shape (catdims, A, Bs... )
641- RT = Base. promote_eltype (A, Bs... )
642- Res = TracedRArray {RT,length(shape)} (
643- (),
644- MLIR. IR. result (
645- MLIR. Dialects. stablehlo. concatenate (
646- [A. mlir_data, [B. mlir_data for B in Bs]. .. ];
647- result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
648- dimension= D - 1 , # stablehlo expects this to be zero-indexed
649- ),
650- 1 ,
651- ),
652- shape,
653- )
654- return Res
655- end
656-
657- function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
658- D ≤ N && return x
659- return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
660- end
0 commit comments