# reference implementation on the CPU
# This acts as a wrapper around KernelAbstractions's parallel CPU
# functionality. It is useful for testing GPUArrays (and other packages)
# when no GPU is present.
# This file follows conventions from AMDGPU.jl
66
77module JLArrays
88
9+ export JLArray, JLVector, JLMatrix, jl, JLBackend
10+
911using GPUArrays
12+
1013using Adapt
14+
1115import KernelAbstractions
1216import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
1317
14- export JLArray, JLVector, JLMatrix, jl, JLBackend
1518
1619#
1720# Device functionality
@@ -24,7 +27,6 @@ struct JLBackend <: KernelAbstractions.GPU
2427 JLBackend (;static:: Bool = false ) = new (static)
2528end
2629
27-
# Singleton marker type; an instance is passed as the `to` argument of `adapt`.
struct Adaptor end

# Convert a single argument by running it through `adapt` with an `Adaptor`.
function jlconvert(arg)
    return adapt(Adaptor(), arg)
end
3032
# Zero-index dereference (`r[]`) of a `JlRefValue`: returns the value held in
# its `x` field.
function Base.getindex(r::JlRefValue)
    return r.x
end
# Adapt a `Base.RefValue` for an `Adaptor`: the referenced value is adapted
# first, and the result is wrapped in a `JlRefValue`.
function Adapt.adapt_structure(to::Adaptor, r::Base.RefValue)
    inner = adapt(to, r[])
    return JlRefValue(inner)
end
3739
38- mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
39- data:: DataRef{Vector{UInt8}}
40-
41- offset:: Int # offset of the data in the buffer, in number of elements
42-
43- dims:: Dims{N}
44-
45- # allocating constructor
46- function JLArray {T,N} (:: UndefInitializer , dims:: Dims{N} ) where {T,N}
47- check_eltype (T)
48- maxsize = prod (dims) * sizeof (T)
49- data = Vector {UInt8} (undef, maxsize)
50- ref = DataRef (data) do data
51- resize! (data, 0 )
52- end
53- obj = new {T,N} (ref, 0 , dims)
54- finalizer (unsafe_free!, obj)
55- end
56-
57- # low-level constructor for wrapping existing data
58- function JLArray {T,N} (ref:: DataRef{Vector{UInt8}} , dims:: Dims{N} ;
59- offset:: Int = 0 ) where {T,N}
60- check_eltype (T)
61- obj = new {T,N} (ref, offset, dims)
62- finalizer (unsafe_free!, obj)
63- end
64- end
65-
66- Adapt. adapt_storage (:: JLBackend , a:: Array ) = Adapt. adapt (JLArrays. JLArray, a)
67- Adapt. adapt_storage (:: JLBackend , a:: JLArrays.JLArray ) = a
68- Adapt. adapt_storage (:: KernelAbstractions.CPU , a:: JLArrays.JLArray ) = convert (Array, a)
40+ # # executed on-device
6941
7042# array type
7143
# Linear read from a `JLDeviceArray`: the lookup is forwarded to the array
# returned by `typed_data(A)`.
@inline function Base.getindex(A::JLDeviceArray, index::Integer)
    return typed_data(A)[index]
end
# Linear write into a `JLDeviceArray`: the store is forwarded to the array
# returned by `typed_data(A)`.
@inline function Base.setindex!(A::JLDeviceArray, x, index::Integer)
    target = typed_data(A)
    # Call `setindex!` explicitly (not `target[index] = x`) so that the value
    # returned is whatever the underlying `setindex!` returns.
    return setindex!(target, x, index)
end
9365
66+
9467#
9568# Host abstractions
9669#
@@ -104,6 +77,34 @@ function check_eltype(T)
10477 end
10578end
10679
# Host array type. Storage is an untyped byte buffer (`Vector{UInt8}`) held
# through a `DataRef`; the element type `T` and rank `N` live only in the type
# parameters.
mutable struct JLArray{T,N} <: AbstractGPUArray{T,N}
    # reference to the raw byte buffer
    data::DataRef{Vector{UInt8}}

    # offset of the data in the buffer, in number of elements
    offset::Int

    # array dimensions
    dims::Dims{N}

    # Allocating constructor: validates `T`, then creates an uninitialized
    # byte buffer large enough for `prod(dims)` elements of type `T`.
    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
        check_eltype(T)
        nbytes = prod(dims) * sizeof(T)
        buffer = Vector{UInt8}(undef, nbytes)
        # Callback handed to `DataRef`: shrinks the buffer to zero length.
        # NOTE(review): presumably invoked when the reference is released —
        # confirm against the `DataRef` implementation.
        ref = DataRef(buf -> resize!(buf, 0), buffer)
        arr = new{T,N}(ref, 0, dims)
        # `finalizer` returns the object it was given, so this is the result.
        return finalizer(unsafe_free!, arr)
    end

    # Low-level constructor: wraps an existing `DataRef` without allocating,
    # optionally starting at an element `offset` into the buffer.
    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
                          offset::Int=0) where {T,N}
        check_eltype(T)
        arr = new{T,N}(ref, offset, dims)
        return finalizer(unsafe_free!, arr)
    end
end
107+
# Free a `JLArray` by delegating to `GPUArrays.unsafe_free!` on its `data`
# reference.
function unsafe_free!(a::JLArray)
    return GPUArrays.unsafe_free!(a.data)
end
108109
109110# conversion of untyped data to a typed Array
@@ -380,7 +381,10 @@ function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothin
380381 device_args = jlconvert .(args)
381382 new_obj = convert_to_cpu (obj)
382383 new_obj (device_args... ; ndrange, workgroupsize)
383-
384384end
385385
# Storage adaptation for the `JLBackend`: a plain `Array` is adapted to
# `JLArray`.
function Adapt.adapt_storage(::JLBackend, a::Array)
    return Adapt.adapt(JLArrays.JLArray, a)
end
# Storage adaptation for the `JLBackend`: a `JLArray` is returned unchanged.
function Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray)
    return a
end
# Storage adaptation for the KernelAbstractions `CPU` backend: a `JLArray` is
# converted to a plain `Array`.
function Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray)
    return convert(Array, a)
end
389+
386390end
0 commit comments