@@ -17,6 +17,13 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1717 end
1818end
1919
20+ function Base. setproperty! (x:: TracedRArray , f:: Symbol , v)
21+ if f === :mlir_data && ! isnothing (v)
22+ @assert size (MLIR. IR. type (v)) == size (x)
23+ end
24+ return setfield! (x, f, v)
25+ end
26+
2027mutable struct TracedRScalar{T} <: RScalar{T}
2128 paths:: Tuple
2229 mlir_data:: Union{Nothing,MLIR.IR.Value}
@@ -31,6 +38,15 @@ mutable struct TracedRScalar{T} <: RScalar{T}
3138 end
3239end
3340
41+ function Base. setproperty! (x:: TracedRScalar , f:: Symbol , v)
42+ if f === :mlir_data && ! isnothing (v)
43+ @assert size (MLIR. IR. type (v)) == ()
44+ end
45+ return setfield! (x, f, v)
46+ end
47+
48+ Base. eltype (:: Type{TracedRScalar{T}} ) where {T} = T
49+
3450const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
3551const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
3652const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
@@ -57,7 +73,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
5773Base. one (:: TracedRScalar{T} ) where {T} = promote_to (TracedRScalar{T}, one (T))
5874
5975function Base. convert (:: Type{<:TracedRScalar{T}} , x:: Number ) where {T}
60- return promote_to (TracedRArray{T, 0 }, T (x))
76+ return promote_to (TracedRScalar{T }, T (x))
6177end
6278
6379function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -119,7 +135,7 @@ function Base.setindex!(
119135 a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
120136) where {T,N}
121137 indices = [
122- (promote_to (TracedRArray {Int, 0 }, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
138+ (promote_to (TracedRScalar {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
123139 i in indices
124140 ]
125141 v = promote_to (TracedRArray{T,N}, v)
@@ -220,6 +236,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
220236 return TracedRArray{Base. promote_type (T, S),N}
221237end
222238
239+ function Base. promote_rule (:: Type{T} , :: Type{TracedRScalar{S}} ) where {T,S}
240+ return TracedRScalar{Base. promote_type (T, S)}
241+ end
242+
243+ function Base. convert (:: Type{TracedRScalar{T}} , x:: Number ) where {T}
244+ return promote_to (TracedRScalar{T}, x)
245+ end
246+
223247function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
224248 if isa (rhs, TracedRArray)
225249 rhs isa TracedRArray{T,N} && return rhs
@@ -277,12 +301,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
277301 )
278302end
279303
280- function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
281- return promote_to (TracedRArray{T,N}, rhs)
282- end
283- function promote_to (:: TracedRScalar{T} , rhs) where {T}
284- return promote_to (TracedRScalar{T}, rhs)
285- end
304+ promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
305+ promote_to (:: TracedRScalar{T} , rhs) where {T} = promote_to (TracedRScalar{T}, rhs)
286306
287307for (jlop, hloop) in (
288308 (:(Base. min), :minimum ),
@@ -293,66 +313,35 @@ for (jlop, hloop) in (
293313 (:(Base.:/ ), :divide ),
294314 (:(Base.:^ ), :power ),
295315)
296- @eval begin
297- function $ (jlop)(
298- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
299- ) where {T}
300- return TracedRArray {T,0} (
301- (),
302- MLIR. IR. result (
303- MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
304- ),
305- (),
306- )
307- end
308-
309- function $ (jlop)(
310- @nospecialize (lhs:: TracedRArray{T1,0} ), @nospecialize (rhs:: TracedRArray{T2,0} )
311- ) where {T1,T2}
312- commonTy = TracedRArray{Base. promote_type (T1, T2),0 }
313- lhs = promote_to (commonTy, lhs)
314- rhs = promote_to (commonTy, rhs)
315- return $ (jlop)(lhs, rhs)
316- end
317- end
318-
319- for otherType in (Number, Any)
320- @eval begin
321- function $ (jlop)(
322- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: $ (otherType))
323- ) where {T}
324- rhs = promote_to (lhs, rhs)
325- return $ (jlop)(lhs, rhs)
326- end
327-
328- function $ (jlop)(
329- @nospecialize (lhs:: $ (otherType)), @nospecialize (rhs:: TracedRArray{T,0} )
330- ) where {T}
331- lhs = promote_to (rhs, lhs)
332- return $ (jlop)(lhs, rhs)
333- end
334- end
316+ @eval function $ (jlop)(
317+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
318+ ) where {T}
319+ return TracedRArray {T} (
320+ (),
321+ MLIR. IR. result (
322+ MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
323+ ),
324+ )
335325 end
336326end
337327
338328function Base. ifelse (
339- @nospecialize (pred:: TracedRArray {Bool,0 } ),
340- @nospecialize (x:: TracedRArray {T1,0 } ),
341- @nospecialize (y:: TracedRArray {T2,0 } )
329+ @nospecialize (pred:: TracedRScalar {Bool} ),
330+ @nospecialize (x:: TracedRScalar {T1} ),
331+ @nospecialize (y:: TracedRScalar {T2} )
342332) where {T1,T2}
343- return TracedRArray {promote_type(T1, T2),0 } (
333+ return TracedRScalar {promote_type(T1, T2)} (
344334 (),
345335 MLIR. IR. result (
346336 MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
347337 ),
348- size (pred),
349338 )
350339end
351340
352- Base. abs2 (x:: Reactant.TracedRArray{T,0 } ) where {T} = x * conj (x)
341+ Base. abs2 (x:: Reactant.TracedRScalar{T } ) where {T} = x * conj (x)
353342
354343function Base. literal_pow (
355- :: Base.RefValue{typeof(^)} , x:: TracedRArray{T,0 } , :: Base.RefValue{Val{P}}
344+ :: Base.RefValue{typeof(^)} , x:: TracedRScalar{T } , :: Base.RefValue{Val{P}}
356345) where {T,P}
357346 return Base. literal_pow (^ , x, Val (P))
358347end
@@ -369,14 +358,10 @@ for (jlop, hloop) in (
369358 (:(Base. log), :log ),
370359 (:(Base. sqrt), :sqrt ),
371360)
372- @eval begin
373- function $jlop (@nospecialize (lhs:: TracedRArray{T,0} )) where {T}
374- return TracedRArray {T,0} (
375- (),
376- MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 ),
377- size (lhs),
378- )
379- end
361+ @eval function $ (jlop)(@nospecialize (lhs:: TracedRScalar{T} )) where {T}
362+ return TracedRScalar {T} (
363+ (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
364+ )
380365 end
381366end
382367
@@ -443,6 +428,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
443428 residx = 1
444429
445430 for a in linear_results
431+ @show a
446432 if has_residx (a)
447433 path = get_residx (a)
448434 set! (result, path[2 : end ], MLIR. IR. result (res, residx))
@@ -478,37 +464,22 @@ for (jlop, hloop, hlocomp, merge) in (
478464 (:(Base.:(<= )), :compare , " LE" , nothing ),
479465 (:(Base.:(< )), :compare , " LT" , nothing ),
480466)
481- @eval begin
482- function $ (jlop)(
483- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
484- ) where {T}
485- return TracedRArray {Bool,0} (
486- (),
487- MLIR. IR. result (
488- MLIR. Dialects. stablehlo.$ hloop (
489- lhs. mlir_data,
490- rhs. mlir_data;
491- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
492- MLIR. IR. context (), $ hlocomp
493- ),
467+ @eval function $ (jlop)(
468+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
469+ ) where {T}
470+ return TracedRScalar {Bool} (
471+ (),
472+ MLIR. IR. result (
473+ MLIR. Dialects. stablehlo.$ (hloop)(
474+ lhs. mlir_data,
475+ rhs. mlir_data;
476+ comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
477+ MLIR. IR. context (), $ hlocomp
494478 ),
495- 1 ,
496479 ),
497- size (lhs),
498- )
499- end
500-
501- function $ (jlop)(
502- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs)
503- ) where {T}
504- return $ (jlop)(lhs, promote_to (lhs, rhs))
505- end
506-
507- function $ (jlop)(
508- @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,0} )
509- ) where {T}
510- return $ (jlop)(promote_to (rhs, lhs), rhs)
511- end
480+ 1 ,
481+ ),
482+ )
512483 end
513484
514485 if merge != = nothing
@@ -598,7 +569,7 @@ function Base.mapreduce(
598569 fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in in_tys])
599570
600571 args = (
601- TracedRArray {T,0 } ((), MLIR. IR. argument (fnbody, i), ()) for
572+ TracedRScalar {T } ((), MLIR. IR. argument (fnbody, i), ()) for
602573 (i, ty) in enumerate (in_tys)
603574 )
604575
0 commit comments