@@ -36,20 +36,17 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
3636 return @allowscalar dest[CartesianIndex ()] # 0D broadcast needs to unwrap results
3737end
3838
39- # We purposefully only specialize `copyto!`, dependent packages need to make sure that they
40- # can handle:
41- # - `bc::Broadcast.Broadcasted{Style}`
42- # - `ex::Broadcast.Extruded`
43- # - `LinearAlgebra.Transpose{,<:AbstractGPUArray}` and `LinearAlgebra.Adjoint{,<:AbstractGPUArray}`, etc
44- # as arguments to a kernel and that they do the right conversion.
45- #
46- # This Broadcast can be further customize by:
47- # - `Broadcast.preprocess(dest::AbstractGPUArray, bc::Broadcasted{Nothing})` which allows for a
48- # complete transformation based on the output type just at the end of the pipeline.
49- # - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
50- # with `Style`
51- #
52- # For more information see the Base documentation.
39+ # we need to override the outer copy method to make sure we never fall back to scalar
40+ # iteration (see, e.g., CUDA.jl#145)
41+ @inline function Broadcast. copy (bc:: Broadcasted{<:AbstractGPUArrayStyle} )
42+ ElType = Broadcast. combine_eltypes (bc. f, bc. args)
43+ if ! Base. isconcretetype (ElType)
44+ error (""" GPU broadcast resulted in non-concrete element type $ElType .
45+ This probably means that the function you are broadcasting contains an error or type instability.""" )
46+ end
47+ copyto! (similar (bc, ElType), bc)
48+ end
49+
5350@inline function Base. copyto! (dest:: BroadcastGPUArray , bc:: Broadcasted{Nothing} )
5451 axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
5552 isempty (dest) && return dest
0 commit comments