@@ -192,18 +192,20 @@ function ConcretePJRTArray(
192192 return ConcretePJRTArray {T,N,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
193193end
194194
195- function ConcretePJRTArray (
196- data:: Memory{T} ;
197- client:: Union{Nothing,XLA.PJRT.Client} = nothing ,
198- idx:: Union{Int,Nothing} = nothing ,
199- device:: Union{Nothing,XLA.PJRT.Device} = nothing ,
200- sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
201- ) where {T}
202- theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
203- sharded_data, shardinfo = sharding (theclient, thedevice, data)
204- shape = size (data)
205- nsharded = length (sharded_data)
206- return ConcretePJRTArray {T,1,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
195+ if isdefined (Base, :Memory )
196+ function ConcretePJRTArray (
197+ data:: Memory{T} ;
198+ client:: Union{Nothing,XLA.PJRT.Client} = nothing ,
199+ idx:: Union{Int,Nothing} = nothing ,
200+ device:: Union{Nothing,XLA.PJRT.Device} = nothing ,
201+ sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
202+ ) where {T}
203+ theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
204+ sharded_data, shardinfo = sharding (theclient, thedevice, data)
205+ shape = size (data)
206+ nsharded = length (sharded_data)
207+ return ConcretePJRTArray {T,1,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
208+ end
207209end
208210
209211Base. wait (x:: Union{ConcretePJRTArray,ConcretePJRTNumber} ) = foreach (wait, x. data)
@@ -334,17 +336,19 @@ function ConcreteIFRTArray(
334336 return ConcreteIFRTArray {T,N,typeof(shardinfo)} (sharded_data, shape, shardinfo, padding)
335337end
336338
337- function ConcreteIFRTArray (
338- data:: Memory{T} ;
339- client:: Union{Nothing,XLA.IFRT.Client} = nothing ,
340- idx:: Union{Int,Nothing} = nothing ,
341- device:: Union{Nothing,XLA.IFRT.Device} = nothing ,
342- sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
343- ) where {T}
344- theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
345- sharded_data, shardinfo, padding = sharding (theclient, nothing , data)
346- shape = size (data)
347- return ConcreteIFRTArray {T,1,typeof(shardinfo)} (sharded_data, shape, shardinfo)
339+ if isdefined (Base, :Memory )
340+ function ConcreteIFRTArray (
341+ data:: Memory{T} ;
342+ client:: Union{Nothing,XLA.IFRT.Client} = nothing ,
343+ idx:: Union{Int,Nothing} = nothing ,
344+ device:: Union{Nothing,XLA.IFRT.Device} = nothing ,
345+ sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
346+ ) where {T}
347+ theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
348+ sharded_data, shardinfo, padding = sharding (theclient, nothing , data)
349+ shape = size (data)
350+ return ConcreteIFRTArray {T,1,typeof(shardinfo)} (sharded_data, shape, shardinfo)
351+ end
348352end
349353
350354# Assemble data from multiple arrays. Needed in distributed setting where each process wont
0 commit comments