Skip to content

Commit 7c8e040

Browse files
committed
change to return the type
1 parent 418cc18 commit 7c8e040

File tree

3 files changed

+52
-29
lines changed

3 files changed

+52
-29
lines changed

src/ChainRulesCore.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
1111
export frule_via_ad, rrule_via_ad
1212
# definition helper macros
1313
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
14-
export ProjectTo, canonicalize, unthunk # tangent operations
14+
export ProjectTo, differential_type, canonicalize, unthunk # tangent operations
1515
export add!! # gradient accumulation operations
16-
export ignore_derivatives, @ignore_derivatives, is_non_differentiable
16+
export ignore_derivatives, @ignore_derivatives
1717
# tangents
1818
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1919

src/projection.jl

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040
backing(project::ProjectTo) = getfield(project, :info)
4141

4242
project_type(p::ProjectTo{T}) where {T} = T
43+
project_type(::Type{<:ProjectTo{T}}) where {T} = T
44+
project_type(_) = Any
4345

4446
function Base.show(io::IO, project::ProjectTo{T}) where {T}
4547
print(io, "ProjectTo{")
@@ -142,42 +144,16 @@ end
142144
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
143145
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
144146

145-
#####
146-
##### A related utility which wants to live nearby
147-
#####
148-
149-
"""
150-
is_non_differentiable(x) == is_non_differentiable(typeof(x))
151-
152-
Returns `true` if `x` is known from its type not to have derivatives, else `false`.
153-
154-
Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
155-
which is what the fallback method checks. The exception is that it will not look
156-
inside abstractly typed containers like `x = Any[true, false]`.
157-
"""
158-
is_non_differentiable(x) = is_non_differentiable(typeof(x))
159-
160-
is_non_differentiable(::Type{<:Number}) = false
161-
is_non_differentiable(::Type{<:NTuple{N,T}}) where {N,T} = is_non_differentiable(T)
162-
is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T)
163-
164-
function is_non_differentiable(::Type{T}) where {T} # fallback
165-
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
166-
return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero}
167-
end
168-
169147
#####
170148
##### `Base`
171149
#####
172150

173151
# Bool
174152
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
175-
is_non_differentiable(::Type{Bool}) = true
176153

177154
# Other never-differentiable types
178-
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
155+
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle, :Nothing)
179156
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
180-
@eval is_non_differentiable(::Type{<:$T}) = true
181157
end
182158

183159
# Numbers
@@ -627,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
627603
invoke(project, Tuple{AbstractArray}, dx)
628604
end
629605
end
606+
607+
#####
608+
##### A related utility which wants to live nearby
609+
#####
610+
611+
"""
612+
differential_type(x)
613+
differential_type(typeof(x))
614+
615+
Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is
616+
known to be non-differentiable.
617+
618+
This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference.
619+
Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`.
620+
621+
```jldoctest
622+
julia> differential_type(true)
623+
NoTangent
624+
625+
julia> differential_type(Int)
626+
Float64
627+
628+
julia> x = Any[true, false];
629+
630+
julia> differential_type(x)
631+
NoTangent
632+
633+
julia> differential_type(typeof(x))
634+
Any
635+
```
636+
"""
637+
differential_type(x) = project_type(ProjectTo(x))
638+
639+
function differential_type(::Type{T}) where {T}
640+
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
641+
return isconcretetype(PT) ? project_type(PT) : Any
642+
end

test/projection.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,13 @@ struct NoSuperType end
478478
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
479479
end
480480
end
481+
482+
@testset "differential_type" begin
483+
@test differential_type(true) == differential_type(Bool) == NoTangent
484+
@test differential_type(1) == differential_type(Int) == Float64
485+
tup = (false, :x, nothing)
486+
@test differential_type(tup) == differential_type(typeof(tup)) == NoTangent
487+
488+
@test differential_type(NoSuperType()) == differential_type(NoSuperType) == Any
489+
@test differential_type(Dual(1,2)) == differential_type(Dual) == Real
490+
end

0 commit comments

Comments
 (0)