@@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
40
40
backing (project:: ProjectTo ) = getfield (project, :info )
41
41
42
42
project_type (p:: ProjectTo{T} ) where {T} = T
43
+ project_type (:: Type{<:ProjectTo{T}} ) where {T} = T
44
+ project_type (_) = Any
43
45
44
46
function Base. show (io:: IO , project:: ProjectTo{T} ) where {T}
45
47
print (io, " ProjectTo{" )
@@ -142,42 +144,16 @@ end
142
144
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
143
145
(:: ProjectTo{T} )(dx:: Tangent{<:T} ) where {T} = dx
144
146
145
- # ####
146
- # #### A related utility which wants to live nearby
147
- # ####
148
-
149
- """
150
- is_non_differentiable(x) == is_non_differentiable(typeof(x))
151
-
152
- Returns `true` if `x` is known from its type not to have derivatives, else `false`.
153
-
154
- Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
155
- which is what the fallback method checks. The exception is that it will not look
156
- inside abstractly typed containers like `x = Any[true, false]`.
157
- """
158
- is_non_differentiable (x) = is_non_differentiable (typeof (x))
159
-
160
- is_non_differentiable (:: Type{<:Number} ) = false
161
- is_non_differentiable (:: Type{<:NTuple{N,T}} ) where {N,T} = is_non_differentiable (T)
162
- is_non_differentiable (:: Type{<:AbstractArray{T}} ) where {T} = is_non_differentiable (T)
163
-
164
- function is_non_differentiable (:: Type{T} ) where {T} # fallback
165
- PT = Base. _return_type (ProjectTo, Tuple{T}) # might be Union{} if unstable
166
- return isconcretetype (PT) && PT <: ProjectTo{<:AbstractZero}
167
- end
168
-
169
147
# ####
170
148
# #### `Base`
171
149
# ####
172
150
173
151
# Bool
174
152
ProjectTo (:: Bool ) = ProjectTo {NoTangent} () # same projector as ProjectTo(::AbstractZero) above
175
- is_non_differentiable (:: Type{Bool} ) = true
176
153
177
154
# Other never-differentiable types
178
- for T in (:Symbol , :Char , :AbstractString , :RoundingMode , :IndexStyle )
155
+ for T in (:Symbol , :Char , :AbstractString , :RoundingMode , :IndexStyle , :Nothing )
179
156
@eval ProjectTo (:: $T ) = ProjectTo {NoTangent} ()
180
- @eval is_non_differentiable (:: Type{<:$T} ) = true
181
157
end
182
158
183
159
# Numbers
@@ -627,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
627
603
invoke (project, Tuple{AbstractArray}, dx)
628
604
end
629
605
end
606
+
607
+ # ####
608
+ # #### A related utility which wants to live nearby
609
+ # ####
610
+
611
+ """
612
+ differential_type(x)
613
+ differential_type(typeof(x))
614
+
615
+ Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is
616
+ known to be non-differentiable.
617
+
618
+ This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference.
619
+ Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`.
620
+
621
+ ```jldoctest
622
+ julia> differential_type(true)
623
+ NoTangent
624
+
625
+ julia> differential_type(Int)
626
+ Float64
627
+
628
+ julia> x = Any[true, false];
629
+
630
+ julia> differential_type(x)
631
+ NoTangent
632
+
633
+ julia> differential_type(typeof(x))
634
+ Any
635
+ ```
636
+ """
637
+ differential_type (x) = project_type (ProjectTo (x))
638
+
639
+ function differential_type (:: Type{T} ) where {T}
640
+ PT = Base. _return_type (ProjectTo, Tuple{T}) # might be Union{} if unstable
641
+ return isconcretetype (PT) ? project_type (PT) : Any
642
+ end
0 commit comments