@@ -23,6 +23,8 @@ Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)
2323
2424Base. IndexStyle (:: Type{<:TracedRArray} ) = Base. IndexLinear ()
2525
26+ Base. elsize (:: Type{TracedRArray{T,N}} ) where {T,N} = sizeof (T)
27+
2628# This is required otherwise we will copy a tracedrarray each time
2729# we use it
2830Base. convert (T:: Type{<:TracedRArray} , x:: AbstractArray ) = Reactant. promote_to (T, x)
@@ -265,10 +267,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
265267 return print (io, " TracedRArray{" , T, " ," , N, " N}(" , X. paths, " , size=" , size (X), " )" )
266268end
267269
268- function Base. permutedims (A:: AnyTracedRArray{T,N} , perm) where {T,N}
269- return @opcall transpose (materialize_traced_array (A), Int64[perm... ])
270- end
271-
272270for (jlop, hloop, hlocomp, merge) in
273271 ((:(Base.:(== )), :compare , " EQ" , :all ), (:(Base.:(!= )), :compare , " NE" , :any ))
274272 @eval function $jlop (
@@ -279,6 +277,17 @@ for (jlop, hloop, hlocomp, merge) in
279277 end
280278end
281279
280+ # Override _parentsmatch to avoid pointer comparisons during tracing
281+ # Direct TracedRArray comparisons - they don't alias unless they're the same object
282+ Base. _parentsmatch (A:: TracedRArray , B:: TracedRArray ) = A === B
283+ # ReshapedArray comparisons - check if they share the same parent (more specific than StridedArray)
284+ function Base. _parentsmatch (
285+ A:: Base.ReshapedArray{<:TracedRNumber,<:Any,<:Union{TracedRArray,SubArray{<:TracedRNumber,<:Any,<:TracedRArray}}} ,
286+ B:: Base.ReshapedArray{<:TracedRNumber,<:Any,<:Union{TracedRArray,SubArray{<:TracedRNumber,<:Any,<:TracedRArray}}}
287+ )
288+ return Base. _parentsmatch (parent (A), parent (B))
289+ end
290+
282291function __default_init (
283292 :: Type{T} , :: Union{typeof(Base.min),typeof(Base.FastMath.min_fast)}
284293) where {T}
@@ -1348,4 +1357,15 @@ function unrolled_map(f::F, itr) where {F}
13481357 return result
13491358end
13501359
1360+ # permutedims for TracedRArrays and wrappers
1361+ function Base. permutedims (A:: AnyTracedRArray{T,N} , perm) where {T,N}
1362+ return @opcall transpose (materialize_traced_array (A), Int64[perm... ])
1363+ end
1364+
1365+ function Base. permutedims! (dest:: TracedRArray , src:: AnyTracedRArray , perm)
1366+ result = @opcall transpose (materialize_traced_array (src), Int64[perm... ])
1367+ TracedUtils. set_mlir_data! (dest, result. mlir_data)
1368+ return dest
1369+ end
1370+
13511371end
0 commit comments