|
1 | 1 | import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
|
2 |
| -using Zygote: @adjoint |
| 2 | +using Base.Broadcast: broadcasted |
| 3 | +import Zygote: @adjoint, accum_sum, unbroadcast, Numeric, ∇getindex, _project |
3 | 4 |
|
4 | 5 | function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
|
5 | 6 | mapreduce(*, +, value(a), value(b))
|
@@ -56,3 +57,56 @@ ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
|
56 | 57 | (p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
|
57 | 58 | ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar} = ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
|
58 | 59 | (p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x
|
| 60 | +accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims) |
| 61 | + |
| 62 | +TaylorNumeric{T<:TaylorScalar} = Union{T, AbstractArray{<:T}} |
| 63 | + |
| 64 | +@adjoint broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...) = broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) |
| 65 | + |
| 66 | +struct TaylorOneElement{T,N,I,A} <: AbstractArray{T,N} |
| 67 | + val::T |
| 68 | + ind::I |
| 69 | + axes::A |
| 70 | + TaylorOneElement(val::T, ind::I, axes::A) where {T<:TaylorScalar, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes) |
| 71 | +end |
| 72 | + |
| 73 | +Base.size(A::TaylorOneElement) = map(length, A.axes) |
| 74 | +Base.axes(A::TaylorOneElement) = A.axes |
| 75 | +Base.getindex(A::TaylorOneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) |
| 76 | + |
| 77 | +∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N} = dy -> begin |
| 78 | + dx = TaylorOneElement(dy, inds, axes(x)) |
| 79 | + return (_project(x, dx), map(_->nothing, inds)...) |
| 80 | +end |
| 81 | + |
| 82 | +@generated function mul_adjoint(Ω::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N} |
| 83 | + return quote |
| 84 | + vΩ, vx = value(Ω), value(x) |
| 85 | + @inbounds TaylorScalar($([:(+($([:($(binomial(j - 1, i - 1)) * vΩ[$j] * |
| 86 | + vx[$(j + 1 - i)]) for j in i:N]...))) |
| 87 | + for i in 1:N]...)) |
| 88 | + end |
| 89 | +end |
| 90 | + |
| 91 | +rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x) |
| 92 | + |
| 93 | +function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar) |
| 94 | + function times_pullback2(Ω̇) |
| 95 | + ΔΩ = unthunk(Ω̇) |
| 96 | + return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)), ProjectTo(y)(mul_adjoint(ΔΩ, x))) |
| 97 | + end |
| 98 | + return x * y, times_pullback2 |
| 99 | +end |
| 100 | + |
| 101 | +function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...) |
| 102 | + Ω2, back2 = rrule(*, x, y) |
| 103 | + Ω3, back3 = rrule(*, Ω2, z) |
| 104 | + Ω4, back4 = rrule(*, Ω3, more...) |
| 105 | + function times_pullback4(Ω̇) |
| 106 | + Δ4 = back4(unthunk(Ω̇)) # (0, ΔΩ3, Δmore...) |
| 107 | + Δ3 = back3(Δ4[2]) # (0, ΔΩ2, Δz) |
| 108 | + Δ2 = back2(Δ3[2]) # (0, Δx, Δy) |
| 109 | + return (Δ2..., Δ3[3], Δ4[3:end]...) |
| 110 | + end |
| 111 | + return Ω4, times_pullback4 |
| 112 | +end |
0 commit comments