Skip to content

Commit 259e29e

Browse files
Make RArray a subtype of DenseArray (#1696)
* Make `RArray` a subtype of `DenseArray` * Make `AnyTracedRArray` a `DenseArray` subtype * Fix permutedims errors * Fix errors * Simplify method definitions * Add _parentsmatch method
1 parent abba4f4 commit 259e29e

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

src/ConcreteRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ for jlop in (
200200
:(Base.:^),
201201
:(Base.:(==)),
202202
),
203-
T in (AbstractConcreteNumber, AbstractConcreteArray{<:Any,0})
203+
T in (AbstractConcreteNumber, AbstractConcreteArray{<:Number,0})
204204

205205
@eval begin
206206
$(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y))

src/TracedRArray.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)
2323

2424
Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear()
2525

26+
Base.elsize(::Type{TracedRArray{T,N}}) where {T,N} = sizeof(T)
27+
2628
# This is required otherwise we will copy a tracedrarray each time
2729
# we use it
2830
Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x)
@@ -265,10 +267,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
265267
return print(io, "TracedRArray{", T, ",", N, "N}(", X.paths, ", size=", size(X), ")")
266268
end
267269

268-
function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
269-
return @opcall transpose(materialize_traced_array(A), Int64[perm...])
270-
end
271-
272270
for (jlop, hloop, hlocomp, merge) in
273271
((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any))
274272
@eval function $jlop(
@@ -279,6 +277,17 @@ for (jlop, hloop, hlocomp, merge) in
279277
end
280278
end
281279

280+
# Override _parentsmatch to avoid pointer comparisons during tracing
281+
# Direct TracedRArray comparisons - they don't alias unless they're the same object
282+
Base._parentsmatch(A::TracedRArray, B::TracedRArray) = A === B
283+
# ReshapedArray comparisons - check if they share the same parent (more specific than StridedArray)
284+
function Base._parentsmatch(
285+
A::Base.ReshapedArray{<:TracedRNumber,<:Any,<:Union{TracedRArray,SubArray{<:TracedRNumber,<:Any,<:TracedRArray}}},
286+
B::Base.ReshapedArray{<:TracedRNumber,<:Any,<:Union{TracedRArray,SubArray{<:TracedRNumber,<:Any,<:TracedRArray}}}
287+
)
288+
return Base._parentsmatch(parent(A), parent(B))
289+
end
290+
282291
function __default_init(
283292
::Type{T}, ::Union{typeof(Base.min),typeof(Base.FastMath.min_fast)}
284293
) where {T}
@@ -1348,4 +1357,15 @@ function unrolled_map(f::F, itr) where {F}
13481357
return result
13491358
end
13501359

1360+
# permutedims for TracedRArrays and wrappers
1361+
function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
1362+
return @opcall transpose(materialize_traced_array(A), Int64[perm...])
1363+
end
1364+
1365+
function Base.permutedims!(dest::TracedRArray, src::AnyTracedRArray, perm)
1366+
result = @opcall transpose(materialize_traced_array(src), Int64[perm...])
1367+
TracedUtils.set_mlir_data!(dest, result.mlir_data)
1368+
return dest
1369+
end
1370+
13511371
end

src/Types.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end
22

33
abstract type AbstractConcreteNumber{T} <: RNumber{T} end
44

5-
abstract type RArray{T,N} <: AbstractArray{T,N} end
5+
abstract type RArray{T,N} <: DenseArray{T,N} end
66

77
abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end
88

@@ -52,6 +52,11 @@ mutable struct TracedRNumber{T} <: RNumber{T}
5252
end
5353
end
5454

55+
Base.elsize(::Type{TracedRNumber{T}}) where {T} = sizeof(T)
56+
Base.elsize(::Type{RNumber{T}}) where {T} = sizeof(T)
57+
Base.elsize(::Type{<:AbstractConcreteNumber{T}}) where {T} = sizeof(T)
58+
Base.elsize(::Type{<:AbstractConcreteArray{T}}) where {T} = sizeof(T)
59+
5560
function repath(x::TracedRNumber{T}, paths) where {T}
5661
return TracedRNumber{T}(paths, x.mlir_data)
5762
end

0 commit comments

Comments
 (0)