|
47 | 47 | @inline function _copyto!(dest::AbstractArray, bc::Broadcasted) |
48 | 48 | axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) |
49 | 49 | isempty(dest) && return dest |
50 | | - |
51 | | - # to help Enzyme.jl, we won't pass the broadcasted object directly |
52 | | - # but instead pass its arguments and reconstruct the object device-side |
53 | 50 | bc = Broadcast.preprocess(dest, bc) |
54 | | - bcstyle = @static if VERSION >= v"1.10-" |
55 | | - bc.style |
56 | | - else |
57 | | - typeof(BroadcastStyle(typeof(bc))) |
58 | | - end |
59 | 51 |
|
60 | 52 | broadcast_kernel = if ndims(dest) == 1 || |
61 | 53 | (isa(IndexStyle(dest), IndexLinear) && |
62 | 54 | isa(IndexStyle(bc), IndexLinear)) |
63 | | - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) |
64 | | - bc′ = @static if VERSION >= v"1.10-" |
65 | | - Broadcasted(bcstyle, bcf, bcargs, bcaxes) |
66 | | - else |
67 | | - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) |
68 | | - end |
69 | | - |
| 55 | + function (ctx, dest, bc, nelem) |
70 | 56 | i = 1 |
71 | 57 | while i <= nelem |
72 | 58 | I = @linearidx(dest, i) |
73 | | - @inbounds dest[I] = bc′[I] |
| 59 | + @inbounds dest[I] = bc[I] |
74 | 60 | i += 1 |
75 | 61 | end |
76 | 62 | return |
77 | 63 | end |
78 | 64 | else |
79 | | - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) |
80 | | - bc′ = @static if VERSION >= v"1.10-" |
81 | | - Broadcasted(bcstyle, bcf, bcargs, bcaxes) |
82 | | - else |
83 | | - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) |
84 | | - end |
85 | | - |
| 65 | + function (ctx, dest, bc, nelem) |
86 | 66 | i = 0 |
87 | 67 | while i < nelem |
88 | 68 | i += 1 |
89 | 69 | I = @cartesianidx(dest, i) |
90 | | - @inbounds dest[I] = bc′[I] |
| 70 | + @inbounds dest[I] = bc[I] |
91 | 71 | end |
92 | 72 | return |
93 | 73 | end |
94 | 74 | end |
95 | 75 |
|
96 | 76 | elements = length(dest) |
97 | 77 | elements_per_thread = typemax(Int) |
98 | | - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1, |
99 | | - bcstyle, bc.f, bc.axes, bc.args...; |
| 78 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; |
100 | 79 | elements, elements_per_thread) |
101 | 80 | config = launch_configuration(backend(dest), heuristic; |
102 | 81 | elements, elements_per_thread) |
103 | | - gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int, |
104 | | - bcstyle, bc.f, bc.axes, bc.args...; |
| 82 | + gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; |
105 | 83 | threads=config.threads, blocks=config.blocks) |
106 | 84 |
|
107 | 85 | if eltype(dest) <: BrokenBroadcast |
|
0 commit comments