@@ -36,20 +36,17 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
36
36
return @allowscalar dest[CartesianIndex ()] # 0D broadcast needs to unwrap results
37
37
end
38
38
39
- # We purposefully only specialize `copyto!`, dependent packages need to make sure that they
40
- # can handle:
41
- # - `bc::Broadcast.Broadcasted{Style}`
42
- # - `ex::Broadcast.Extruded`
43
- # - `LinearAlgebra.Transpose{,<:AbstractGPUArray}` and `LinearAlgebra.Adjoint{,<:AbstractGPUArray}`, etc
44
- # as arguments to a kernel and that they do the right conversion.
45
- #
46
- # This Broadcast can be further customized by:
47
- # - `Broadcast.preprocess(dest::AbstractGPUArray, bc::Broadcasted{Nothing})` which allows for a
48
- # complete transformation based on the output type just at the end of the pipeline.
49
- # - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
50
- # with `Style`
51
- #
52
- # For more information see the Base documentation.
39
+ # we need to override the outer copy method to make sure we never fall back to scalar
40
+ # iteration (see, e.g., CUDA.jl#145)
41
@inline function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle})
    # Element type reported by `Broadcast.combine_eltypes` for applying `bc.f` to `bc.args`.
    T = Broadcast.combine_eltypes(bc.f, bc.args)

    # Reject non-concrete element types before any destination is allocated.
    # The message is only built on the failure path.
    Base.isconcretetype(T) || error(
        "GPU broadcast resulted in non-concrete element type $(T).\n" *
        "This probably means that the function you are broadcasting contains " *
        "an error or type instability.",
    )

    # Allocate the destination with `similar`, then fill it through `copyto!`,
    # whose result is what this method returns.
    dest = similar(bc, T)
    return copyto!(dest, bc)
end
49
+
53
50
@inline function Base. copyto! (dest:: BroadcastGPUArray , bc:: Broadcasted{Nothing} )
54
51
axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
55
52
isempty (dest) && return dest
0 commit comments