1919
2020TracedRArray {T,N} (x:: TracedRArray{T,N} ) where {T,N} = x
2121
22+ mutable struct TracedRScalar{T} <: RScalar{T}
23+ paths:: Tuple
24+ mlir_data:: Union{Nothing,MLIR.IR.Value}
25+
26+ function TracedRScalar {T} (
27+ paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
28+ ) where {T}
29+ if ! isnothing (mlir_data)
30+ @assert size (MLIR. IR. type (mlir_data)) == ()
31+ end
32+ return new {T} (paths, mlir_data)
33+ end
34+ end
35+
2236const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
2337const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
24- const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
2538const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
2639const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
2740const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -40,12 +53,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
4053 return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
4154end
4255
43- Base. getindex (a:: AnyTracedRScalar {T} ) where {T} = a
56+ Base. getindex (a:: TracedRScalar {T} ) where {T} = a
4457
45- Base. zero (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, zero (T))
46- Base. one (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, one (T))
58+ Base. zero (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, zero (T))
59+ Base. one (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, one (T))
4760
48- function Base. convert (:: Type{<:AnyTracedRScalar {T}} , x:: Number ) where {T}
61+ function Base. convert (:: Type{<:TracedRScalar {T}} , x:: Number ) where {T}
4962 return promote_to (TracedRArray{T,0 }, T (x))
5063end
5164
@@ -73,7 +86,7 @@ and require expensive copies and synchronization each time and therefore should
7386 ),
7487 1 ,
7588 )
76- return TracedRArray {T,0 } ((), res2, () )
89+ return TracedRScalar {T } ((), res2)
7790end
7891
7992function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
@@ -133,7 +146,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
133146 # return print(io, X.mlir_data, ")")
134147end
135148
136- Base. only (A:: AnyTracedRScalar{T} ) where {T} = A
149+ function Base. show (io:: IOty , X:: TracedRScalar{T} ) where {T,IOty<: Union{IO,IOContext} }
150+ return print (io, " TracedRScalar{" , T, " }(" , X. paths, " )" )
151+ end
152+
153+ Base. only (A:: TracedRScalar{T} ) where {T} = A
137154
138155function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
139156 if prod (dims) != prod (size (A))
207224
208225function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
209226 if isa (rhs, TracedRArray)
210- if typeof (rhs) == TracedRArray{T,N}
211- return rhs
212- end
227+ rhs isa TracedRArray{T,N} && return rhs
213228 return TracedRArray {T,N} (
214229 (),
215230 MLIR. IR. result (
@@ -222,11 +237,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
222237 )
223238 end
224239 if isa (rhs, Number)
225- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRArray{T,N}, size (rhs)))
226- ta = TracedRArray {T,N} (
227- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
228- )
229- return ta
240+ throw (ArgumentError (" Cannot promote number to `TracedRArray`. Use \
241+ `TracedRScalar` instead." ))
230242 end
231243 T0 = eltype (rhs)
232244 attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
@@ -238,9 +250,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
238250 )
239251end
240252
253+ function promote_to (:: Type{TracedRScalar{T}} , rhs) where {T}
254+ if isa (rhs, TracedRScalar)
255+ rhs isa TracedRScalar{T} && return rhs
256+ return TracedRScalar {T} (
257+ (),
258+ MLIR. IR. result (
259+ MLIR. Dialects. stablehlo. convert (
260+ rhs. mlir_data; result= mlir_type (TracedRScalar{T})
261+ ),
262+ 1 ,
263+ ),
264+ )
265+ end
266+ if isa (rhs, Number)
267+ attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRScalar{T}))
268+ return TracedRScalar {T} (
269+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
270+ )
271+ end
272+ T0 = eltype (rhs)
273+ attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
274+ return promote_to (
275+ TracedRScalar{T},
276+ TracedRScalar {T0} (
277+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
278+ ),
279+ )
280+ end
281+
241282function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
242283 return promote_to (TracedRArray{T,N}, rhs)
243284end
285+ function promote_to (:: TracedRScalar{T} , rhs) where {T}
286+ return promote_to (TracedRScalar{T}, rhs)
287+ end
244288
245289for (jlop, hloop) in (
246290 (:(Base. min), :minimum ),
0 commit comments