Skip to content

Commit cbd5e75

Browse files
committed
Safeguard Memory for v1.10
1 parent 15eca0b commit cbd5e75

File tree

3 files changed

+111
-104
lines changed

3 files changed

+111
-104
lines changed

src/Tracing.jl

Lines changed: 70 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,95 +1812,97 @@ Base.@nospecializeinfer function make_tracer(
18121812
return res
18131813
end
18141814

1815-
1816-
Base.@nospecializeinfer function make_tracer(
1817-
seen,
1818-
@nospecialize(prev::Memory),
1819-
@nospecialize(path),
1820-
mode;
1821-
@nospecialize(track_numbers::Type = Union{}),
1822-
@nospecialize(sharding = Sharding.NoSharding()),
1823-
@nospecialize(runtime = nothing),
1824-
@nospecialize(device = nothing),
1825-
@nospecialize(client = nothing),
1826-
kwargs...,
1827-
)
1828-
RT = Core.Typeof(prev)
1829-
# XXX: If someone wants to shard the same array with different shardings, we need to
1830-
# somehow handle this correctly... Right now we just use the first sharding.
1831-
if mode != NoStopTracedTrack && haskey(seen, prev)
1832-
if mode == TracedToTypes
1833-
visited = seen[prev]
1834-
push!(path, visited)
1835-
return nothing
1815+
if isdefined(Base, :Memory)
1816+
Base.@nospecializeinfer function make_tracer(
1817+
seen,
1818+
@nospecialize(prev::Memory),
1819+
@nospecialize(path),
1820+
mode;
1821+
@nospecialize(track_numbers::Type = Union{}),
1822+
@nospecialize(sharding = Sharding.NoSharding()),
1823+
@nospecialize(runtime = nothing),
1824+
@nospecialize(device = nothing),
1825+
@nospecialize(client = nothing),
1826+
kwargs...,
1827+
)
1828+
RT = Core.Typeof(prev)
1829+
# XXX: If someone wants to shard the same array with different shardings, we need to
1830+
# somehow handle this correctly... Right now we just use the first sharding.
1831+
if mode != NoStopTracedTrack && haskey(seen, prev)
1832+
if mode == TracedToTypes
1833+
visited = seen[prev]
1834+
push!(path, visited)
1835+
return nothing
1836+
end
1837+
return seen[prev]
18361838
end
1837-
return seen[prev]
1838-
end
1839-
if eltype(RT) <: ReactantPrimitive
1840-
if mode == ArrayToConcrete
1841-
runtime isa Val{:PJRT} &&
1842-
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
1843-
runtime isa Val{:IFRT} &&
1844-
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
1845-
error("Unsupported runtime $runtime")
1839+
if eltype(RT) <: ReactantPrimitive
1840+
if mode == ArrayToConcrete
1841+
runtime isa Val{:PJRT} &&
1842+
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
1843+
runtime isa Val{:IFRT} &&
1844+
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
1845+
error("Unsupported runtime $runtime")
1846+
elseif mode == TracedToTypes
1847+
# Original array can get mutated so we store a copy:
1848+
push!(path, copy(prev))
1849+
seen[prev] = VisitedObject(length(seen) + 1)
1850+
return nothing
1851+
end
18461852
elseif mode == TracedToTypes
1847-
# Original array can get mutated so we store a copy:
1848-
push!(path, copy(prev))
1849-
seen[prev] = VisitedObject(length(seen) + 1)
1853+
push!(path, RT)
1854+
for I in eachindex(prev)
1855+
if isassigned(prev, I)
1856+
pv = prev[I]
1857+
make_tracer(
1858+
seen,
1859+
pv,
1860+
path,
1861+
mode;
1862+
track_numbers,
1863+
sharding,
1864+
runtime,
1865+
device,
1866+
client,
1867+
kwargs...,
1868+
)
1869+
end
1870+
end
18501871
return nothing
18511872
end
1852-
elseif mode == TracedToTypes
1853-
push!(path, RT)
1873+
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
1874+
newa = Array{TT,ndims(RT)}(undef, size(prev))
1875+
seen[prev] = newa
1876+
same = true
18541877
for I in eachindex(prev)
18551878
if isassigned(prev, I)
18561879
pv = prev[I]
1857-
make_tracer(
1880+
nv = make_tracer(
18581881
seen,
18591882
pv,
1860-
path,
1883+
append_path(path, I),
18611884
mode;
18621885
track_numbers,
1863-
sharding,
1886+
sharding=Base.getproperty(sharding, I),
18641887
runtime,
18651888
device,
18661889
client,
18671890
kwargs...,
18681891
)
1892+
if pv !== nv
1893+
same = false
1894+
end
1895+
@inbounds newa[I] = nv
18691896
end
18701897
end
1871-
return nothing
1872-
end
1873-
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
1874-
newa = Array{TT,ndims(RT)}(undef, size(prev))
1875-
seen[prev] = newa
1876-
same = true
1877-
for I in eachindex(prev)
1878-
if isassigned(prev, I)
1879-
pv = prev[I]
1880-
nv = make_tracer(
1881-
seen,
1882-
pv,
1883-
append_path(path, I),
1884-
mode;
1885-
track_numbers,
1886-
sharding=Base.getproperty(sharding, I),
1887-
runtime,
1888-
device,
1889-
client,
1890-
kwargs...,
1891-
)
1892-
if pv !== nv
1893-
same = false
1894-
end
1895-
@inbounds newa[I] = nv
1898+
if same
1899+
seen[prev] = prev
1900+
return prev
18961901
end
1902+
return newa
18971903
end
1898-
if same
1899-
seen[prev] = prev
1900-
return prev
1901-
end
1902-
return newa
19031904
end
1905+
19041906
Base.@nospecializeinfer function make_tracer(
19051907
seen,
19061908
@nospecialize(prev::Sharding.Mesh),

src/Types.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -228,18 +228,20 @@ function ConcretePJRTArray(
228228
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
229229
end
230230

231-
function ConcretePJRTArray(
232-
data::Memory{T};
233-
client::Union{Nothing,XLA.PJRT.Client}=nothing,
234-
idx::Union{Int,Nothing}=nothing,
235-
device::Union{Nothing,XLA.PJRT.Device}=nothing,
236-
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
237-
) where {T}
238-
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
239-
sharded_data, shardinfo = sharding(theclient, thedevice, data)
240-
shape = size(data)
241-
nsharded = length(sharded_data)
242-
return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
231+
if isdefined(Base, :Memory)
232+
function ConcretePJRTArray(
233+
data::Memory{T};
234+
client::Union{Nothing,XLA.PJRT.Client}=nothing,
235+
idx::Union{Int,Nothing}=nothing,
236+
device::Union{Nothing,XLA.PJRT.Device}=nothing,
237+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
238+
) where {T}
239+
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
240+
sharded_data, shardinfo = sharding(theclient, thedevice, data)
241+
shape = size(data)
242+
nsharded = length(sharded_data)
243+
return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
244+
end
243245
end
244246

245247
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
@@ -370,17 +372,19 @@ function ConcreteIFRTArray(
370372
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
371373
end
372374

373-
function ConcreteIFRTArray(
374-
data::Memory{T};
375-
client::Union{Nothing,XLA.IFRT.Client}=nothing,
376-
idx::Union{Int,Nothing}=nothing,
377-
device::Union{Nothing,XLA.IFRT.Device}=nothing,
378-
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
379-
) where {T}
380-
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
381-
sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
382-
shape = size(data)
383-
return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo)
375+
if isdefined(Base, :Memory)
376+
function ConcreteIFRTArray(
377+
data::Memory{T};
378+
client::Union{Nothing,XLA.IFRT.Client}=nothing,
379+
idx::Union{Int,Nothing}=nothing,
380+
device::Union{Nothing,XLA.IFRT.Device}=nothing,
381+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
382+
) where {T}
383+
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
384+
sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
385+
shape = size(data)
386+
return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo)
387+
end
384388
end
385389

386390
# Assemble data from multiple arrays. Needed in distributed setting where each process wont

src/xla/PJRT/Buffer.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
2121
return Buffer(buffer)
2222
end
2323

24-
25-
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
26-
sizear = collect(Int64, reverse(size(memory)))
27-
buffer = GC.@preserve memory sizear begin
28-
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
29-
client.client::Ptr{Cvoid},
30-
pointer(memory)::Ptr{T},
31-
XLA.primitive_type(T)::UInt64,
32-
1::Csize_t,
33-
pointer(sizear)::Ptr{Int64},
34-
device.device::Ptr{Cvoid},
35-
)::Ptr{Cvoid}
24+
if isdefined(Base, :Memory)
25+
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
26+
sizear = collect(Int64, reverse(size(memory)))
27+
buffer = GC.@preserve memory sizear begin
28+
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
29+
client.client::Ptr{Cvoid},
30+
pointer(memory)::Ptr{T},
31+
XLA.primitive_type(T)::UInt64,
32+
1::Csize_t,
33+
pointer(sizear)::Ptr{Int64},
34+
device.device::Ptr{Cvoid},
35+
)::Ptr{Cvoid}
36+
end
37+
return Buffer(buffer)
3638
end
37-
return Buffer(buffer)
3839
end
3940

4041
function Base.similar(a::Buffer)

0 commit comments

Comments
 (0)