@@ -17,9 +17,22 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1717 end
1818end
1919
20+ mutable struct TracedRScalar{T} <: RScalar{T}
21+ paths:: Tuple
22+ mlir_data:: Union{Nothing,MLIR.IR.Value}
23+
24+ function TracedRScalar {T} (
25+ paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
26+ ) where {T}
27+ if ! isnothing (mlir_data)
28+ @assert size (MLIR. IR. type (mlir_data)) == ()
29+ end
30+ return new {T} (paths, mlir_data)
31+ end
32+ end
33+
2034const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
2135const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
22- const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
2336const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
2437const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
2538const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -38,12 +51,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
3851 return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
3952end
4053
41- Base. getindex (a:: AnyTracedRScalar {T} ) where {T} = a
54+ Base. getindex (a:: TracedRScalar {T} ) where {T} = a
4255
43- Base. zero (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, zero (T))
44- Base. one (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, one (T))
56+ Base. zero (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, zero (T))
57+ Base. one (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, one (T))
4558
46- function Base. convert (:: Type{<:AnyTracedRScalar {T}} , x:: Number ) where {T}
59+ function Base. convert (:: Type{<:TracedRScalar {T}} , x:: Number ) where {T}
4760 return promote_to (TracedRArray{T,0 }, T (x))
4861end
4962
@@ -71,7 +84,7 @@ and require expensive copies and synchronization each time and therefore should
7184 ),
7285 1 ,
7386 )
74- return TracedRArray {T,0 } ((), res2, () )
87+ return TracedRScalar {T } ((), res2)
7588end
7689
7790function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
@@ -131,7 +144,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
131144 # return print(io, X.mlir_data, ")")
132145end
133146
134- Base. only (A:: AnyTracedRScalar{T} ) where {T} = A
147+ function Base. show (io:: IOty , X:: TracedRScalar{T} ) where {T,IOty<: Union{IO,IOContext} }
148+ return print (io, " TracedRScalar{" , T, " }(" , X. paths, " )" )
149+ end
150+
151+ Base. only (A:: TracedRScalar{T} ) where {T} = A
135152
136153function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
137154 if prod (dims) != prod (size (A))
205222
206223function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
207224 if isa (rhs, TracedRArray)
208- if typeof (rhs) == TracedRArray{T,N}
209- return rhs
210- end
225+ rhs isa TracedRArray{T,N} && return rhs
211226 return TracedRArray {T,N} (
212227 (),
213228 MLIR. IR. result (
@@ -220,11 +235,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
220235 )
221236 end
222237 if isa (rhs, Number)
223- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRArray{T,N}, size (rhs)))
224- ta = TracedRArray {T,N} (
225- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
226- )
227- return ta
238+ throw (ArgumentError (" Cannot promote number to `TracedRArray`. Use \
239+ `TracedRScalar` instead." ))
228240 end
229241 T0 = eltype (rhs)
230242 attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
@@ -236,9 +248,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
236248 )
237249end
238250
251+ function promote_to (:: Type{TracedRScalar{T}} , rhs) where {T}
252+ if isa (rhs, TracedRScalar)
253+ rhs isa TracedRScalar{T} && return rhs
254+ return TracedRScalar {T} (
255+ (),
256+ MLIR. IR. result (
257+ MLIR. Dialects. stablehlo. convert (
258+ rhs. mlir_data; result= mlir_type (TracedRScalar{T})
259+ ),
260+ 1 ,
261+ ),
262+ )
263+ end
264+ if isa (rhs, Number)
265+ attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRScalar{T}))
266+ return TracedRScalar {T} (
267+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
268+ )
269+ end
270+ T0 = eltype (rhs)
271+ attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
272+ return promote_to (
273+ TracedRScalar{T},
274+ TracedRScalar {T0} (
275+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
276+ ),
277+ )
278+ end
279+
239280function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
240281 return promote_to (TracedRArray{T,N}, rhs)
241282end
283+ function promote_to (:: TracedRScalar{T} , rhs) where {T}
284+ return promote_to (TracedRScalar{T}, rhs)
285+ end
242286
243287for (jlop, hloop) in (
244288 (:(Base. min), :minimum ),
0 commit comments