Skip to content

Commit 17236cc

Browse files
authored
Merge pull request #345 from JuliaGPU/tb/broadcast_instabilities
Properly error when broadcasting type-unstable functions.
2 parents 89bbe0d + 858e2a3 commit 17236cc

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

src/host/broadcast.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,17 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
3636
return @allowscalar dest[CartesianIndex()] # 0D broadcast needs to unwrap results
3737
end
3838

39-
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
40-
# can handle:
41-
# - `bc::Broadcast.Broadcasted{Style}`
42-
# - `ex::Broadcast.Extruded`
43-
# - `LinearAlgebra.Transpose{,<:AbstractGPUArray}` and `LinearAlgebra.Adjoint{,<:AbstractGPUArray}`, etc
44-
# as arguments to a kernel and that they do the right conversion.
45-
#
46-
# This Broadcast can be further customize by:
47-
# - `Broadcast.preprocess(dest::AbstractGPUArray, bc::Broadcasted{Nothing})` which allows for a
48-
# complete transformation based on the output type just at the end of the pipeline.
49-
# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
50-
# with `Style`
51-
#
52-
# For more information see the Base documentation.
39+
# we need to override the outer copy method to make sure we never fall back to scalar
40+
# iteration (see, e.g., CUDA.jl#145)
41+
@inline function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle})
42+
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
43+
if !Base.isconcretetype(ElType)
44+
error("""GPU broadcast resulted in non-concrete element type $ElType.
45+
This probably means that the function you are broadcasting contains an error or type instability.""")
46+
end
47+
copyto!(similar(bc, ElType), bc)
48+
end
49+
5350
@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing})
5451
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
5552
isempty(dest) && return dest

test/testsuite/broadcasting.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
@testsuite "broadcasting" AT->begin
22
broadcasting(AT)
33
vec3(AT)
4+
5+
@testset "type instabilities" begin
6+
f(x) = x ? 1.0 : 0
7+
try
8+
f.(AT(rand(Bool, 1)))
9+
catch err
10+
@test err isa ErrorException
11+
@test contains(err.msg, "GPU broadcast resulted in non-concrete element type")
12+
end
13+
end
414
end
515

616
test_idx(idx, A::AbstractArray{T}) where T = A[idx] * T(2)

0 commit comments

Comments
 (0)