@@ -66,6 +66,10 @@ and require expensive copies and synchronization each time and therefore should
6666 return TracedRNumber {T} ((), res2)
6767end
6868
69+ function Base. getindex (a:: TracedRArray{T,0} ) where {T}
70+ return TracedRNumber {T} ((), a. mlir_data)
71+ end
72+
6973function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
7074 indices = [i isa Colon ? (1 : size (a, idx)) : i for (idx, i) in enumerate (indices)]
7175 res = MLIR. IR. result (
@@ -222,7 +226,12 @@ function elem_apply(
222226end
223227
224228function elem_apply (f, args:: Vararg{Any,Nargs} ) where {Nargs}
225- all (iszero ∘ ndims, args) && return f (args... )
229+ if all (iszero ∘ ndims, args)
230+ scalar_args = map (args) do arg
231+ return promote_to (TracedRNumber{eltype (arg)}, arg)
232+ end
233+ return f (scalar_args... )
234+ end
226235
227236 fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn (
228237 f, args, (), string (f) * " _broadcast_scalar" , false ; toscalar= true
@@ -440,6 +449,12 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N}
440449 return A
441450end
442451
452+ function Base. fill! (A:: TracedRArray{T,N} , x:: TracedRNumber{T2} ) where {T,N,T2}
453+ bcast = broadcast_to_size (promote_to (TracedRNumber{T}, x), size (A))
454+ A. mlir_data = bcast. mlir_data
455+ return A
456+ end
457+
443458struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
444459
445460AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
458473
459474function Base. similar (
460475 bc:: Broadcasted{AbstractReactantArrayStyle{N}} , :: Type{T} , dims
461- ) where {T,N}
476+ ) where {T<: ReactantPrimitives ,N}
477+ @assert N isa Int
478+ return TracedRArray {T,N} ((), nothing , map (length, dims))
479+ end
480+
481+ function Base. similar (
482+ bc:: Broadcasted{AbstractReactantArrayStyle{N}} , :: Type{<:TracedRNumber{T}} , dims
483+ ) where {T<: ReactantPrimitives ,N}
462484 @assert N isa Int
463485 return TracedRArray {T,N} ((), nothing , map (length, dims))
464486end
@@ -536,7 +558,7 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number}
536558end
537559
538560function broadcast_to_size (arg:: TracedRNumber , rsize)
539- rsize == () && return arg
561+ length ( rsize) == 0 && return arg
540562 mlirty = MLIR. IR. type (arg. mlir_data)
541563 return TracedRArray {eltype(arg),length(rsize)} (
542564 (),
0 commit comments