1919
2020TracedRArray {T,N} (x:: TracedRArray{T,N} ) where {T,N} = x
2121
22+ function Base. setproperty! (x:: TracedRArray , f:: Symbol , v)
23+ if f === :mlir_data && ! isnothing (v)
24+ @assert size (MLIR. IR. type (v)) == size (x)
25+ end
26+ return setfield! (x, f, v)
27+ end
28+
2229mutable struct TracedRScalar{T} <: RScalar{T}
2330 paths:: Tuple
2431 mlir_data:: Union{Nothing,MLIR.IR.Value}
@@ -33,6 +40,15 @@ mutable struct TracedRScalar{T} <: RScalar{T}
3340 end
3441end
3542
43+ function Base. setproperty! (x:: TracedRScalar , f:: Symbol , v)
44+ if f === :mlir_data && ! isnothing (v)
45+ @assert size (MLIR. IR. type (v)) == ()
46+ end
47+ return setfield! (x, f, v)
48+ end
49+
50+ Base. eltype (:: Type{TracedRScalar{T}} ) where {T} = T
51+
3652const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
3753const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
3854const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
@@ -59,7 +75,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
5975Base. one (:: TracedRScalar{T} ) where {T} = promote_to (TracedRScalar{T}, one (T))
6076
6177function Base. convert (:: Type{<:TracedRScalar{T}} , x:: Number ) where {T}
62- return promote_to (TracedRArray{T, 0 }, T (x))
78+ return promote_to (TracedRScalar{T }, T (x))
6379end
6480
6581function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -121,7 +137,7 @@ function Base.setindex!(
121137 a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
122138) where {T,N}
123139 indices = [
124- (promote_to (TracedRArray {Int, 0 }, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
140+ (promote_to (TracedRScalar {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
125141 i in indices
126142 ]
127143 v = promote_to (TracedRArray{T,N}, v)
@@ -222,6 +238,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
222238 return TracedRArray{Base. promote_type (T, S),N}
223239end
224240
241+ function Base. promote_rule (:: Type{T} , :: Type{TracedRScalar{S}} ) where {T,S}
242+ return TracedRScalar{Base. promote_type (T, S)}
243+ end
244+
245+ function Base. convert (:: Type{TracedRScalar{T}} , x:: Number ) where {T}
246+ return promote_to (TracedRScalar{T}, x)
247+ end
248+
225249function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
226250 if isa (rhs, TracedRArray)
227251 rhs isa TracedRArray{T,N} && return rhs
@@ -279,12 +303,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
279303 )
280304end
281305
282- function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
283- return promote_to (TracedRArray{T,N}, rhs)
284- end
285- function promote_to (:: TracedRScalar{T} , rhs) where {T}
286- return promote_to (TracedRScalar{T}, rhs)
287- end
306+ promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
307+ promote_to (:: TracedRScalar{T} , rhs) where {T} = promote_to (TracedRScalar{T}, rhs)
288308
289309for (jlop, hloop) in (
290310 (:(Base. min), :minimum ),
@@ -295,66 +315,35 @@ for (jlop, hloop) in (
295315 (:(Base.:/ ), :divide ),
296316 (:(Base.:^ ), :power ),
297317)
298- @eval begin
299- function $ (jlop)(
300- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
301- ) where {T}
302- return TracedRArray {T,0} (
303- (),
304- MLIR. IR. result (
305- MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
306- ),
307- (),
308- )
309- end
310-
311- function $ (jlop)(
312- @nospecialize (lhs:: TracedRArray{T1,0} ), @nospecialize (rhs:: TracedRArray{T2,0} )
313- ) where {T1,T2}
314- commonTy = TracedRArray{Base. promote_type (T1, T2),0 }
315- lhs = promote_to (commonTy, lhs)
316- rhs = promote_to (commonTy, rhs)
317- return $ (jlop)(lhs, rhs)
318- end
319- end
320-
321- for otherType in (Number, Any)
322- @eval begin
323- function $ (jlop)(
324- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: $ (otherType))
325- ) where {T}
326- rhs = promote_to (lhs, rhs)
327- return $ (jlop)(lhs, rhs)
328- end
329-
330- function $ (jlop)(
331- @nospecialize (lhs:: $ (otherType)), @nospecialize (rhs:: TracedRArray{T,0} )
332- ) where {T}
333- lhs = promote_to (rhs, lhs)
334- return $ (jlop)(lhs, rhs)
335- end
336- end
318+ @eval function $ (jlop)(
319+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
320+ ) where {T}
321+ return TracedRArray {T} (
322+ (),
323+ MLIR. IR. result (
324+ MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
325+ ),
326+ )
337327 end
338328end
339329
340330function Base. ifelse (
341- @nospecialize (pred:: TracedRArray {Bool,0 } ),
342- @nospecialize (x:: TracedRArray {T1,0 } ),
343- @nospecialize (y:: TracedRArray {T2,0 } )
331+ @nospecialize (pred:: TracedRScalar {Bool} ),
332+ @nospecialize (x:: TracedRScalar {T1} ),
333+ @nospecialize (y:: TracedRScalar {T2} )
344334) where {T1,T2}
345- return TracedRArray {promote_type(T1, T2),0 } (
335+ return TracedRScalar {promote_type(T1, T2)} (
346336 (),
347337 MLIR. IR. result (
348338 MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
349339 ),
350- size (pred),
351340 )
352341end
353342
354- Base. abs2 (x:: Reactant.TracedRArray{T,0 } ) where {T} = x * conj (x)
343+ Base. abs2 (x:: Reactant.TracedRScalar{T } ) where {T} = x * conj (x)
355344
356345function Base. literal_pow (
357- :: Base.RefValue{typeof(^)} , x:: TracedRArray{T,0 } , :: Base.RefValue{Val{P}}
346+ :: Base.RefValue{typeof(^)} , x:: TracedRScalar{T } , :: Base.RefValue{Val{P}}
358347) where {T,P}
359348 return Base. literal_pow (^ , x, Val (P))
360349end
@@ -371,14 +360,10 @@ for (jlop, hloop) in (
371360 (:(Base. log), :log ),
372361 (:(Base. sqrt), :sqrt ),
373362)
374- @eval begin
375- function $jlop (@nospecialize (lhs:: TracedRArray{T,0} )) where {T}
376- return TracedRArray {T,0} (
377- (),
378- MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 ),
379- size (lhs),
380- )
381- end
363+ @eval function $ (jlop)(@nospecialize (lhs:: TracedRScalar{T} )) where {T}
364+ return TracedRScalar {T} (
365+ (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
366+ )
382367 end
383368end
384369
@@ -445,6 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
445430 residx = 1
446431
447432 for a in linear_results
433+ @show a
448434 if has_residx (a)
449435 path = get_residx (a)
450436 set! (result, path[2 : end ], MLIR. IR. result (res, residx))
@@ -480,37 +466,22 @@ for (jlop, hloop, hlocomp, merge) in (
480466 (:(Base.:(<= )), :compare , " LE" , nothing ),
481467 (:(Base.:(< )), :compare , " LT" , nothing ),
482468)
483- @eval begin
484- function $ (jlop)(
485- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
486- ) where {T}
487- return TracedRArray {Bool,0} (
488- (),
489- MLIR. IR. result (
490- MLIR. Dialects. stablehlo.$ hloop (
491- lhs. mlir_data,
492- rhs. mlir_data;
493- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
494- MLIR. IR. context (), $ hlocomp
495- ),
469+ @eval function $ (jlop)(
470+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
471+ ) where {T}
472+ return TracedRScalar {Bool} (
473+ (),
474+ MLIR. IR. result (
475+ MLIR. Dialects. stablehlo.$ (hloop)(
476+ lhs. mlir_data,
477+ rhs. mlir_data;
478+ comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
479+ MLIR. IR. context (), $ hlocomp
496480 ),
497- 1 ,
498481 ),
499- size (lhs),
500- )
501- end
502-
503- function $ (jlop)(
504- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs)
505- ) where {T}
506- return $ (jlop)(lhs, promote_to (lhs, rhs))
507- end
508-
509- function $ (jlop)(
510- @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,0} )
511- ) where {T}
512- return $ (jlop)(promote_to (rhs, lhs), rhs)
513- end
482+ 1 ,
483+ ),
484+ )
514485 end
515486
516487 if merge != = nothing
@@ -600,7 +571,7 @@ function Base.mapreduce(
600571 fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in in_tys])
601572
602573 args = (
603- TracedRArray {T,0 } ((), MLIR. IR. argument (fnbody, i), ()) for
574+ TracedRScalar {T } ((), MLIR. IR. argument (fnbody, i), ()) for
604575 (i, ty) in enumerate (in_tys)
605576 )
606577
0 commit comments