@@ -176,22 +176,14 @@ _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
176
176
_dual_safearg (x:: Union{Type,Val,Symbol} ) = true # non-differentiable types
177
177
_dual_safearg (x) = false
178
178
179
- # This is Broadcast.combine_eltypes but with dual eltypes:
180
- _combine_dual_eltypes (f, args:: Tuple ) =
181
- Broadcast. promote_typejoin_union (Base. _return_type (f, map (_dual_eltype, args)))
182
- _dual_eltype (x:: Numeric{T} ) where {T<: Real } = Dual{Nothing, T, 1 } # typeof(Dual(one(T),true))
183
- _dual_eltype (x) = eltype (x)
184
-
185
179
@adjoint function broadcasted (:: AbstractArrayStyle , f:: F , args... ) where {F}
186
- TD = _combine_dual_eltypes (f, args)
180
+ T = Broadcast . combine_eltypes (f, args)
187
181
# Avoid generic broadcasting in two easy cases:
188
- if TD <: Dual && isconcretetype (TD)
189
- if _dual_purefun (F) && all (_dual_safearg, args)
190
- y, back = broadcast_forward (f, args... )
191
- return y, ȳ -> (nothing , nothing , back (ȳ)... )
192
- end
193
- elseif TD <: Real && isconcretetype (TD)
194
- return f .(args... ), _-> nothing
182
+ if T == Bool
183
+ return f .(args... ), _-> nothing
184
+ elseif T <: Real && isconcretetype (T) && _dual_purefun (F) && all (_dual_safearg, args)
185
+ y, back = broadcast_forward (f, args... )
186
+ return y, ȳ -> (nothing , nothing , back (ȳ)... )
195
187
end
196
188
len = inclen (args)
197
189
y∂b = _broadcast ((x... ) -> _pullback (__context__, f, x... ), args... )
0 commit comments