|
2 | 2 |
|
3 | 3 | using Base.Broadcast |
4 | 4 |
|
5 | | -import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate |
| 5 | +using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate |
6 | 6 |
|
7 | 7 | # but make sure we don't dispatch to the optimized copy method that directly indexes |
8 | 8 | function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}}) |
|
32 | 32 | return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) |
33 | 33 | end |
34 | 34 |
|
35 | | -@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict |
| 35 | +@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = |
| 36 | + _copyto!(dest, bc) # Keep it for ArrayConflict |
36 | 37 |
|
37 | | -@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc) |
| 38 | +@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = |
| 39 | + _copyto!(dest, bc) |
38 | 40 |
|
39 | 41 | @inline function _copyto!(dest::AbstractArray, bc::Broadcasted) |
40 | 42 | axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) |
41 | 43 | isempty(dest) && return dest |
42 | | - bc′ = Broadcast.preprocess(dest, bc) |
43 | | - |
44 | | - # grid-stride kernel |
45 | | - function broadcast_kernel(ctx, dest, bc′, nelem) |
46 | | - i = 0 |
47 | | - while i < nelem |
48 | | - i += 1 |
49 | | - I = @cartesianidx(dest, i) |
50 | | - @inbounds dest[I] = bc′[I] |
| 44 | + bc = Broadcast.preprocess(dest, bc) |
| 45 | + |
| 46 | + broadcast_kernel = if ndims(dest) == 1 || |
| 47 | + (isa(IndexStyle(dest), IndexLinear) && |
| 48 | + isa(IndexStyle(bc), IndexLinear)) |
| 49 | + function (ctx, dest, bc, nelem) |
| 50 | + i = 1 |
| 51 | + while i <= nelem |
| 52 | + I = @linearidx(dest, i) |
| 53 | + @inbounds dest[I] = bc[I] |
| 54 | + i += 1 |
| 55 | + end |
| 56 | + return |
| 57 | + end |
| 58 | + else |
| 59 | + function (ctx, dest, bc, nelem) |
| 60 | + i = 0 |
| 61 | + while i < nelem |
| 62 | + i += 1 |
| 63 | + I = @cartesianidx(dest, i) |
| 64 | + @inbounds dest[I] = bc[I] |
| 65 | + end |
| 66 | + return |
51 | 67 | end |
52 | | - return |
53 | 68 | end |
| 69 | + |
54 | 70 | elements = length(dest) |
55 | 71 | elements_per_thread = typemax(Int) |
56 | | - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1; |
| 72 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; |
57 | 73 | elements, elements_per_thread) |
58 | 74 | config = launch_configuration(backend(dest), heuristic; |
59 | 75 | elements, elements_per_thread) |
60 | | - gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread; |
| 76 | + gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; |
61 | 77 | threads=config.threads, blocks=config.blocks) |
62 | 78 |
|
63 | 79 | return dest |
@@ -101,12 +117,15 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...) |
101 | 117 |
|
102 | 118 | # grid-stride kernel |
103 | 119 | function map_kernel(ctx, dest, bc, nelem) |
104 | | - for i in 1:nelem |
| 120 | + i = 1 |
| 121 | + while i <= nelem |
105 | 122 | j = linear_index(ctx, i) |
106 | 123 | j > common_length && return |
107 | 124 |
|
108 | 125 | J = CartesianIndices(axes(bc))[j] |
109 | 126 | @inbounds dest[j] = bc[J] |
| 127 | + |
| 128 | + i += 1 |
110 | 129 | end |
111 | 130 | return |
112 | 131 | end |
|
0 commit comments