diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index fa8a99fe..dbe43ed2 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -13,6 +13,7 @@ struct ∂☆new{N}; end # we split out the 1st order derivative as a special case for performance # but the nth order case does also work for this function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) + @info "∂☆new{1}" B supertype(B) typeof(xs) primal_args = map(primal, xs) the_primal = _construct(B, primal_args) tangent_tup = map(first_partial, xs) @@ -23,8 +24,9 @@ function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) tangent_nt = NamedTuple{names}(tangent_tup) StructuralTangent{B}(tangent_nt) end + the_final_partial = maybe_construct_natural_tangent(B, the_partial) B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...) - return TaylorBundle{1, B2}(the_primal, (the_partial,)) + return TaylorBundle{1, B2}(the_primal, (the_final_partial,)) end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} @@ -41,7 +43,8 @@ function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} tangent_nt = NamedTuple{names}(tangent_tup) StructuralTangent{B}(tangent_nt) end - return tangent + + return maybe_construct_natural_tangent(B, tangent) end return TaylorBundle{N, B}(the_primal, the_partials) end @@ -50,6 +53,28 @@ _construct(::Type{B}, args) where B<:Tuple = B(args) # Hack for making things that do not have public constructors constructable: @generated _construct(B::Type, args) = Expr(:splatnew, :B, :args) + +maybe_construct_natural_tangent(::Type, structural_tangent) = structural_tangent +for BaseSpaceType in (Number, AbstractArray{<:Number}) + @eval function maybe_construct_natural_tangent(::Type{B}, structural_tangent) where B<:$BaseSpaceType + try + # TODO: should this use `_construct` ? + # TODO: is this right? + unwrap_tup(x::Tangent{<:Tuple}) = ChainRulesCore.backing(x) + unwrap_tup(x) = x + field_tangents = map(unwrap_tup, structural_tangent) + B(field_tangents...) + catch + error( + "`struct` types that subtype `$($BaseSpaceType)` are generally expected to provide default constructors (one arg per field), and to be usable as their own tangent type.\n" * + "If they are not please overload the frule for the constructor: `ChainRulesCore.frule((_, dargs...), ::Type{$B}, args...)` " * + "and make it return whatever tangent type you want. But be warned this is off the beaten track. Here be dragons 🐉" + ) + end + end +end + + @generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{$N}()($(Expr(:new, :B)))) # Sometimes we don't know whether or not we need to the ZeroBundle when doing diff --git a/test/forward.jl b/test/forward.jl index f8040639..0a43feee 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,6 +1,7 @@ module forward_tests using Diffractor -using Diffractor: TaylorBundle, ZeroBundle, ∂☆ +using Diffractor: TaylorBundle, ZeroBundle, DNEBundle, ∂☆ +using Diffractor: first_partial, primal using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad @@ -173,6 +174,32 @@ end end +@testset "custom number types" begin + struct CustomNumber <: Number + val::Float64 + end + + double_and_custom_num(x) = CustomNumber(2.0*x) + let var"'" = Diffractor.PrimeDerivativeFwd + @test double_and_custom_num'(100.0) == CustomNumber(2.0) + end +end + +@testset "custom array type" begin + struct MyLittleStaticVector{N, T} <: AbstractVector{T} + val::NTuple{N, T} + end + Base.size(::MyLittleStaticVector{N}) where N = (N,) + Base.getindex(x::MyLittleStaticVector, ii::Int) = x.val[ii] + + once_twice_three_times(x) = MyLittleStaticVector((x, 2x, 3x)) + @assert once_twice_three_times(10.0) == MyLittleStaticVector((10.0, 20.0, 30.0)) + + 🥯 = ∂☆{1}()(DNEBundle{1}(once_twice_three_times), TaylorBundle{1}(10.0, (1.0,))) + @test primal(🥯) = MyLittleStaticVector((10.0, 20.0, 30.0)) + @test first_partial(🥯) == MyLittleStaticVector((1.0, 2.0, 3.0)) +end + @testset "taylor_compatible" begin taylor_compatible = Diffractor.taylor_compatible