|
19 | 19 |
|
20 | 20 | TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x |
21 | 21 |
|
22 | | -mutable struct TracedRNumber{T} <: RNumber{T} |
| 22 | +const ReactantPrimitives = Union{ |
| 23 | + Bool, |
| 24 | + Int8, |
| 25 | + UInt8, |
| 26 | + Int16, |
| 27 | + UInt16, |
| 28 | + Int32, |
| 29 | + UInt32, |
| 30 | + Int64, |
| 31 | + UInt64, |
| 32 | + Float16, |
| 33 | + Float32, |
| 34 | + # BFloat16, |
| 35 | + Float64, |
| 36 | + Complex{Float32}, |
| 37 | + Complex{Float64}, |
| 38 | +} |
| 39 | + |
| 40 | +# `<: ReactantPrimitives` ensures we don't end up with nested `TracedRNumber`s |
| 41 | +mutable struct TracedRNumber{T<:ReactantPrimitives} <: RNumber{T} |
23 | 42 | paths::Tuple |
24 | 43 | mlir_data::Union{Nothing,MLIR.IR.Value} |
25 | 44 |
|
@@ -214,14 +233,8 @@ function Base.transpose(A::AnyTracedRVecOrMat) |
214 | 233 | end |
215 | 234 | Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A) |
216 | 235 |
|
217 | | -function Base.promote_rule( |
218 | | - ::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}} |
219 | | -) where {T,S,N} |
220 | | - return TracedRArray{Base.promote_type(T, S),N} |
221 | | -end |
222 | | - |
223 | | -function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N} |
224 | | - return TracedRArray{Base.promote_type(T, S),N} |
| 236 | +function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} |
| 237 | + return TracedRNumber{Base.promote_type(T, S)} |
225 | 238 | end |
226 | 239 |
|
227 | 240 | function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} |
@@ -326,8 +339,6 @@ function Base.ifelse( |
326 | 339 | ) |
327 | 340 | end |
328 | 341 |
|
329 | | -Base.abs2(x::Reactant.TracedRNumber{T}) where {T} = x * conj(x) |
330 | | - |
331 | 342 | function Base.literal_pow( |
332 | 343 | ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}} |
333 | 344 | ) where {T,P} |
|
355 | 366 |
|
356 | 367 | struct TypeCast{T<:Number} <: Function end |
357 | 368 |
|
358 | | -function (::TypeCast{T})(x::TracedRArray{T2,0}) where {T,T2} |
359 | | - return promote_to(TracedRArray{T,0}, x) |
| 369 | +function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} |
| 370 | + return promote_to(TracedRNumber{T}, x) |
360 | 371 | end |
361 | 372 |
|
362 | 373 | elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x |
@@ -556,8 +567,7 @@ function Base.mapreduce( |
556 | 567 | fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) |
557 | 568 |
|
558 | 569 | args = ( |
559 | | - TracedRNumber{T}((), MLIR.IR.argument(fnbody, i), ()) for |
560 | | - (i, ty) in enumerate(in_tys) |
| 570 | + TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys) |
561 | 571 | ) |
562 | 572 |
|
563 | 573 | res = MLIR.IR.block!(fnbody) do |
@@ -708,6 +718,25 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number} |
708 | 718 | ) |
709 | 719 | end |
710 | 720 |
|
| 721 | +function broadcast_to_size(arg::TracedRNumber, rsize) |
| 722 | + rsize == () && return arg |
| 723 | + mlirty = MLIR.IR.type(arg.mlir_data) |
| 724 | + return TracedRArray{eltype(arg),length(rsize)}( |
| 725 | + (), |
| 726 | + MLIR.IR.result( |
| 727 | + MLIR.Dialects.stablehlo.broadcast_in_dim( |
| 728 | + arg.mlir_data; |
| 729 | + result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), |
| 730 | + broadcast_dimensions=MLIR.IR.DenseArrayAttribute([ |
| 731 | + Int64(i - 1) for i in rsize |
| 732 | + ]), |
| 733 | + ), |
| 734 | + 1, |
| 735 | + ), |
| 736 | + rsize, |
| 737 | + ) |
| 738 | +end |
| 739 | + |
711 | 740 | function broadcast_to_size(arg::AnyTracedRArray, rsize) |
712 | 741 | arg = materialize_traced_array(arg) |
713 | 742 | size(arg) == rsize && return arg |
|
0 commit comments