Skip to content

Commit 73205e2

Browse files
committed
FixedSizeArray -> ConcreteArray
1 parent 460dd4c commit 73205e2

File tree

4 files changed

+139
-10
lines changed

4 files changed

+139
-10
lines changed

ext/ReactantFixedSizeArraysExt.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,25 @@ using Reactant: TracedRArray, TracedRNumber, Ops
66
using ReactantCore: ReactantCore
77

88
function Reactant.traced_type_inner(
9-
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}}),
9+
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T, N}}),
1010
seen,
1111
@nospecialize(mode::Reactant.TraceMode),
1212
@nospecialize(track_numbers::Type),
1313
@nospecialize(sharding),
1414
@nospecialize(runtime)
15-
) where {T, N, I}
15+
) where {T, N}
1616
T2 = Reactant.TracedRNumber{T}
17-
I2 = Reactant.TracedRNumber{I}
18-
return FixedSizeArrays.FixedSizeArray{T2, N, Memory{I2}}
17+
return FixedSizeArrays.FixedSizeArrayDefault{T2, N}
1918
end
2019

2120
Base.@nospecializeinfer function Reactant.make_tracer(
2221
seen,
23-
@nospecialize(prev::FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}),
22+
@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T, N}),
2423
@nospecialize(path),
2524
mode; kwargs...
26-
) where {T, N, I}
27-
return FixedSizeArrays.FixedSizeArray(
28-
Reactant.make_tracer(
29-
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
30-
)
25+
) where {T, N}
26+
return Reactant.make_tracer(
27+
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
3128
)
3229
end
3330

src/Tracing.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,95 @@ Base.@nospecializeinfer function make_tracer(
18121812
return res
18131813
end
18141814

1815+
1816+
Base.@nospecializeinfer function make_tracer(
1817+
seen,
1818+
@nospecialize(prev::Memory),
1819+
@nospecialize(path),
1820+
mode;
1821+
@nospecialize(track_numbers::Type = Union{}),
1822+
@nospecialize(sharding = Sharding.NoSharding()),
1823+
@nospecialize(runtime = nothing),
1824+
@nospecialize(device = nothing),
1825+
@nospecialize(client = nothing),
1826+
kwargs...,
1827+
)
1828+
RT = Core.Typeof(prev)
1829+
# XXX: If someone wants to shard the same array with different shardings, we need to
1830+
# somehow handle this correctly... Right now we just use the first sharding.
1831+
if mode != NoStopTracedTrack && haskey(seen, prev)
1832+
if mode == TracedToTypes
1833+
visited = seen[prev]
1834+
push!(path, visited)
1835+
return nothing
1836+
end
1837+
return seen[prev]
1838+
end
1839+
if eltype(RT) <: ReactantPrimitive
1840+
if mode == ArrayToConcrete
1841+
runtime isa Val{:PJRT} &&
1842+
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
1843+
runtime isa Val{:IFRT} &&
1844+
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
1845+
error("Unsupported runtime $runtime")
1846+
elseif mode == TracedToTypes
1847+
# Original array can get mutated so we store a copy:
1848+
push!(path, copy(prev))
1849+
seen[prev] = VisitedObject(length(seen) + 1)
1850+
return nothing
1851+
end
1852+
elseif mode == TracedToTypes
1853+
push!(path, RT)
1854+
for I in eachindex(prev)
1855+
if isassigned(prev, I)
1856+
pv = prev[I]
1857+
make_tracer(
1858+
seen,
1859+
pv,
1860+
path,
1861+
mode;
1862+
track_numbers,
1863+
sharding,
1864+
runtime,
1865+
device,
1866+
client,
1867+
kwargs...,
1868+
)
1869+
end
1870+
end
1871+
return nothing
1872+
end
1873+
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
1874+
newa = Array{TT,ndims(RT)}(undef, size(prev))
1875+
seen[prev] = newa
1876+
same = true
1877+
for I in eachindex(prev)
1878+
if isassigned(prev, I)
1879+
pv = prev[I]
1880+
nv = make_tracer(
1881+
seen,
1882+
pv,
1883+
append_path(path, I),
1884+
mode;
1885+
track_numbers,
1886+
sharding=Base.getproperty(sharding, I),
1887+
runtime,
1888+
device,
1889+
client,
1890+
kwargs...,
1891+
)
1892+
if pv !== nv
1893+
same = false
1894+
end
1895+
@inbounds newa[I] = nv
1896+
end
1897+
end
1898+
if same
1899+
seen[prev] = prev
1900+
return prev
1901+
end
1902+
return newa
1903+
end
18151904
Base.@nospecializeinfer function make_tracer(
18161905
seen,
18171906
@nospecialize(prev::Sharding.Mesh),

src/Types.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,20 @@ function ConcretePJRTArray(
228228
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
229229
end
230230

231+
function ConcretePJRTArray(
232+
data::Memory{T};
233+
client::Union{Nothing,XLA.PJRT.Client}=nothing,
234+
idx::Union{Int,Nothing}=nothing,
235+
device::Union{Nothing,XLA.PJRT.Device}=nothing,
236+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
237+
) where {T}
238+
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
239+
sharded_data, shardinfo = sharding(theclient, thedevice, data)
240+
shape = size(data)
241+
nsharded = length(sharded_data)
242+
return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
243+
end
244+
231245
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
232246
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
233247
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
@@ -356,6 +370,19 @@ function ConcreteIFRTArray(
356370
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
357371
end
358372

373+
function ConcreteIFRTArray(
374+
data::Memory{T};
375+
client::Union{Nothing,XLA.IFRT.Client}=nothing,
376+
idx::Union{Int,Nothing}=nothing,
377+
device::Union{Nothing,XLA.IFRT.Device}=nothing,
378+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
379+
) where {T}
380+
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
381+
sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
382+
shape = size(data)
383+
return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo)
384+
end
385+
359386
# Assemble data from multiple arrays. Needed in distributed setting where each process wont
360387
# have enough host memory to hold all the arrays. We assume that the data is only provided
361388
# for all of the addressable devices.

src/xla/PJRT/Buffer.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
2121
return Buffer(buffer)
2222
end
2323

24+
25+
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
26+
sizear = collect(Int64, reverse(size(memory)))
27+
buffer = GC.@preserve memory sizear begin
28+
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
29+
client.client::Ptr{Cvoid},
30+
pointer(memory)::Ptr{T},
31+
XLA.primitive_type(T)::UInt64,
32+
1::Csize_t,
33+
pointer(sizear)::Ptr{Int64},
34+
device.device::Ptr{Cvoid},
35+
)::Ptr{Cvoid}
36+
end
37+
return Buffer(buffer)
38+
end
39+
2440
function Base.similar(a::Buffer)
2541
buffer = GC.@preserve a begin
2642
@ccall MLIR.API.mlir_c.UninitPJRTBuffer(

0 commit comments

Comments
 (0)