@@ -228,18 +228,20 @@ function ConcretePJRTArray(
228228 return ConcretePJRTArray {T,N,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
229229end
230230
231- function ConcretePJRTArray (
232- data:: Memory{T} ;
233- client:: Union{Nothing,XLA.PJRT.Client} = nothing ,
234- idx:: Union{Int,Nothing} = nothing ,
235- device:: Union{Nothing,XLA.PJRT.Device} = nothing ,
236- sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
237- ) where {T}
238- theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
239- sharded_data, shardinfo = sharding (theclient, thedevice, data)
240- shape = size (data)
241- nsharded = length (sharded_data)
242- return ConcretePJRTArray {T,1,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
231+ if isdefined (Base, :Memory )
232+ function ConcretePJRTArray (
233+ data:: Memory{T} ;
234+ client:: Union{Nothing,XLA.PJRT.Client} = nothing ,
235+ idx:: Union{Int,Nothing} = nothing ,
236+ device:: Union{Nothing,XLA.PJRT.Device} = nothing ,
237+ sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
238+ ) where {T}
239+ theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
240+ sharded_data, shardinfo = sharding (theclient, thedevice, data)
241+ shape = size (data)
242+ nsharded = length (sharded_data)
243+ return ConcretePJRTArray {T,1,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
244+ end
243245end
244246
245247Base. wait (x:: Union{ConcretePJRTArray,ConcretePJRTNumber} ) = foreach (wait, x. data)
@@ -370,17 +372,19 @@ function ConcreteIFRTArray(
370372 return ConcreteIFRTArray {T,N,typeof(shardinfo)} (sharded_data, shape, shardinfo, padding)
371373end
372374
373- function ConcreteIFRTArray (
374- data:: Memory{T} ;
375- client:: Union{Nothing,XLA.IFRT.Client} = nothing ,
376- idx:: Union{Int,Nothing} = nothing ,
377- device:: Union{Nothing,XLA.IFRT.Device} = nothing ,
378- sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
379- ) where {T}
380- theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
381- sharded_data, shardinfo, padding = sharding (theclient, nothing , data)
382- shape = size (data)
383- return ConcreteIFRTArray {T,1,typeof(shardinfo)} (sharded_data, shape, shardinfo)
375+ if isdefined (Base, :Memory )
376+ function ConcreteIFRTArray (
377+ data:: Memory{T} ;
378+ client:: Union{Nothing,XLA.IFRT.Client} = nothing ,
379+ idx:: Union{Int,Nothing} = nothing ,
380+ device:: Union{Nothing,XLA.IFRT.Device} = nothing ,
381+ sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
382+ ) where {T}
383+ theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
384+ sharded_data, shardinfo, padding = sharding (theclient, nothing , data)
385+ shape = size (data)
386+ return ConcreteIFRTArray {T,1,typeof(shardinfo)} (sharded_data, shape, shardinfo)
387+ end
384388end
385389
386390# Assemble data from multiple arrays. Needed in distributed setting where each process wont
0 commit comments