Skip to content

Commit 94712cc

Browse files
committed
revert some of that due to 20% slowdown
1 parent 7ebb75b commit 94712cc

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

src/lib/broadcast.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -176,22 +176,14 @@ _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
176176
_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
177177
_dual_safearg(x) = false
178178

179-
# This is Broadcast.combine_eltypes but with dual eltypes:
180-
_combine_dual_eltypes(f, args::Tuple) =
181-
Broadcast.promote_typejoin_union(Base._return_type(f, map(_dual_eltype, args)))
182-
_dual_eltype(x::Numeric{T}) where {T<:Real} = Dual{Nothing, T, 1} # typeof(Dual(one(T),true))
183-
_dual_eltype(x) = eltype(x)
184-
185179
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
186-
TD = _combine_dual_eltypes(f, args)
180+
T = Broadcast.combine_eltypes(f, args)
187181
# Avoid generic broadcasting in two easy cases:
188-
if TD <: Dual && isconcretetype(TD)
189-
if _dual_purefun(F) && all(_dual_safearg, args)
190-
y, back = broadcast_forward(f, args...)
191-
return y, ȳ -> (nothing, nothing, back(ȳ)...)
192-
end
193-
elseif TD <: Real && isconcretetype(TD)
194-
return f.(args...), _->nothing
182+
if T == Bool
183+
return f.(args...), _->nothing
184+
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
185+
y, back = broadcast_forward(f, args...)
186+
return y, ȳ -> (nothing, nothing, back(ȳ)...)
195187
end
196188
len = inclen(args)
197189
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)

0 commit comments

Comments
 (0)