Skip to content

Commit 0363dd1

Browse files
committed
add specialization for TrackedArray
1 parent b48fe60 commit 0363dd1

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/PreallocationTools.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,18 @@ end
8484

8585
# override the [] method
8686
function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray}
87+
s = b.sizemap(size(u)) # required buffer size
88+
buf = get!(b.bufs, (T, s)) do
89+
similar(u, s) # buffer to allocate if it was not found in b.bufs
90+
end::T # declare type since b.bufs dictionary is untyped
91+
return buf
92+
end
93+
94+
function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray)
8795
s = b.sizemap(size(u)) # required buffer size
8896
buf = get!(b.bufs, (T, s)) do
8997
# declare type since b.bufs dictionary is untyped
90-
zero(u)::T # buffer to allocate if it was not found in b.bufs
98+
similar(u, s)::T # buffer to allocate if it was not found in b.bufs
9199
end
92100
return buf
93101
end

0 commit comments

Comments
 (0)