Skip to content

Commit 1e02ac0

Browse files
committed
Fix performance
1 parent 8c8e87f commit 1e02ac0

File tree

4 files changed

+93
-36
lines changed

4 files changed

+93
-36
lines changed

src/chainrules.jl

+55-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
2-
using Zygote: @adjoint
2+
using Base.Broadcast: broadcasted
3+
import Zygote: @adjoint, accum_sum, unbroadcast, Numeric, ∇getindex, _project
34

45
function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
56
mapreduce(*, +, value(a), value(b))
@@ -56,3 +57,56 @@ ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
5657
(p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
5758
ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar} = ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
5859
(p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x
60+
accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims)
61+
62+
TaylorNumeric{T<:TaylorScalar} = Union{T, AbstractArray{<:T}}
63+
64+
@adjoint broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...) = broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
65+
66+
struct TaylorOneElement{T,N,I,A} <: AbstractArray{T,N}
67+
val::T
68+
ind::I
69+
axes::A
70+
TaylorOneElement(val::T, ind::I, axes::A) where {T<:TaylorScalar, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
71+
end
72+
73+
Base.size(A::TaylorOneElement) = map(length, A.axes)
74+
Base.axes(A::TaylorOneElement) = A.axes
75+
Base.getindex(A::TaylorOneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
76+
77+
∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N} = dy -> begin
78+
dx = TaylorOneElement(dy, inds, axes(x))
79+
return (_project(x, dx), map(_->nothing, inds)...)
80+
end
81+
82+
@generated function mul_adjoint::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N}
83+
return quote
84+
vΩ, vx = value(Ω), value(x)
85+
@inbounds TaylorScalar($([:(+($([:($(binomial(j - 1, i - 1)) * vΩ[$j] *
86+
vx[$(j + 1 - i)]) for j in i:N]...)))
87+
for i in 1:N]...))
88+
end
89+
end
90+
91+
rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)
92+
93+
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
94+
function times_pullback2(Ω̇)
95+
ΔΩ = unthunk(Ω̇)
96+
return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)), ProjectTo(y)(mul_adjoint(ΔΩ, x)))
97+
end
98+
return x * y, times_pullback2
99+
end
100+
101+
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
102+
Ω2, back2 = rrule(*, x, y)
103+
Ω3, back3 = rrule(*, Ω2, z)
104+
Ω4, back4 = rrule(*, Ω3, more...)
105+
function times_pullback4(Ω̇)
106+
Δ4 = back4(unthunk(Ω̇)) # (0, ΔΩ3, Δmore...)
107+
Δ3 = back3(Δ4[2]) # (0, ΔΩ2, Δz)
108+
Δ2 = back2(Δ3[2]) # (0, Δx, Δy)
109+
return (Δ2..., Δ3[3], Δ4[3:end]...)
110+
end
111+
return Ω4, times_pullback4
112+
end

src/codegen.jl

+24-24
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
using ChainRulesCore
22
using SymbolicUtils, SymbolicUtils.Code
3-
using SymbolicUtils: Pow
3+
using SymbolicUtils: BasicSymbolic, Pow
44

5-
@scalar_rule +(x::Any) true
6-
@scalar_rule -(x::Any) -1
7-
@scalar_rule deg2rad(x::Any) deg2rad(one(x))
8-
@scalar_rule rad2deg(x::Any) rad2deg(one(x))
9-
@scalar_rule asin(x::Any) inv(sqrt(1 - x^2))
10-
@scalar_rule acos(x::Any) inv(-sqrt(1 - x^2))
11-
@scalar_rule atan(x::Any) inv(-(1 + x^2))
12-
@scalar_rule acot(x::Any) inv(-(1 + x^2))
13-
@scalar_rule acsc(x::Any) inv(x^2 * -sqrt(1 - x^-2))
14-
@scalar_rule asec(x::Any) inv(x^2 * sqrt(1 - x^-2))
15-
@scalar_rule log(x::Any) inv(x)
16-
@scalar_rule log10(x::Any) inv(log(10.0) * x)
17-
@scalar_rule log1p(x::Any) inv(x + 1)
18-
@scalar_rule log2(x::Any) inv(log(2.0) * x)
19-
@scalar_rule sinh(x::Any) cosh(x)
20-
@scalar_rule cosh(x::Any) sinh(x)
21-
@scalar_rule tanh(x::Any) 1-Ω^2
22-
@scalar_rule acosh(x::Any) inv(sqrt(x - 1) * sqrt(x + 1))
23-
@scalar_rule acoth(x::Any) inv(1 - x^2)
24-
@scalar_rule acsch(x::Any) inv(x^2 * -sqrt(1 + x^-2))
25-
@scalar_rule asech(x::Any) inv(x * -sqrt(1 - x^2))
26-
@scalar_rule asinh(x::Any) inv(sqrt(x^2 + 1))
27-
@scalar_rule atanh(x::Any) inv(1 - x^2)
5+
@scalar_rule +(x::BasicSymbolic) true
6+
@scalar_rule -(x::BasicSymbolic) -1
7+
@scalar_rule deg2rad(x::BasicSymbolic) deg2rad(one(x))
8+
@scalar_rule rad2deg(x::BasicSymbolic) rad2deg(one(x))
9+
@scalar_rule asin(x::BasicSymbolic) inv(sqrt(1 - x^2))
10+
@scalar_rule acos(x::BasicSymbolic) inv(-sqrt(1 - x^2))
11+
@scalar_rule atan(x::BasicSymbolic) inv(-(1 + x^2))
12+
@scalar_rule acot(x::BasicSymbolic) inv(-(1 + x^2))
13+
@scalar_rule acsc(x::BasicSymbolic) inv(x^2 * -sqrt(1 - x^-2))
14+
@scalar_rule asec(x::BasicSymbolic) inv(x^2 * sqrt(1 - x^-2))
15+
@scalar_rule log(x::BasicSymbolic) inv(x)
16+
@scalar_rule log10(x::BasicSymbolic) inv(log(10.0) * x)
17+
@scalar_rule log1p(x::BasicSymbolic) inv(x + 1)
18+
@scalar_rule log2(x::BasicSymbolic) inv(log(2.0) * x)
19+
@scalar_rule sinh(x::BasicSymbolic) cosh(x)
20+
@scalar_rule cosh(x::BasicSymbolic) sinh(x)
21+
@scalar_rule tanh(x::BasicSymbolic) 1-Ω^2
22+
@scalar_rule acosh(x::BasicSymbolic) inv(sqrt(x - 1) * sqrt(x + 1))
23+
@scalar_rule acoth(x::BasicSymbolic) inv(1 - x^2)
24+
@scalar_rule acsch(x::BasicSymbolic) inv(x^2 * -sqrt(1 + x^-2))
25+
@scalar_rule asech(x::BasicSymbolic) inv(x * -sqrt(1 - x^2))
26+
@scalar_rule asinh(x::BasicSymbolic) inv(sqrt(x^2 + 1))
27+
@scalar_rule atanh(x::BasicSymbolic) inv(1 - x^2)
2828

2929
dummy = (NoTangent(), 1)
3030
@syms t₁

src/primitive.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ end
156156
end
157157
end
158158

159-
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Real, T <: Number, N} = df * t
159+
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
160160

161161
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
162162
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N

src/scalar.jl

+13-10
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,47 @@ import Base: convert, promote_rule
55
export TaylorScalar
66

77
"""
8-
TaylorScalar{T <: Number, N}
8+
TaylorScalar{T, N}
99
1010
Representation of Taylor polynomials.
1111
1212
# Fields
1313
1414
- `value::NTuple{N, T}`: i-th element of this stores the (i-1)-th derivative
1515
"""
16-
struct TaylorScalar{T <: Number, N}
16+
struct TaylorScalar{T, N}
1717
value::NTuple{N, T}
1818
end
1919

20-
@inline TaylorScalar(xs::Vararg{T, N}) where {T <: Number, N} = TaylorScalar(xs)
20+
TaylorOrNumber = Union{TaylorScalar, Number}
21+
22+
@inline TaylorScalar(xs::Vararg{T, N}) where {T, N} = TaylorScalar(xs)
2123

2224
"""
23-
TaylorScalar{T, N}(x::S) where {S <: Number, T <: Number, N}
25+
TaylorScalar{T, N}(x::T) where {T, N}
2426
2527
Construct a Taylor polynomial with zeroth order coefficient.
2628
"""
27-
@generated function TaylorScalar{T, N}(x::S) where {S <: Number, T <: Number, N}
29+
@generated function TaylorScalar{T, N}(x::T) where {T, N}
2830
return quote
2931
$(Expr(:meta, :inline))
3032
TaylorScalar((T(x), $(zeros(T, N - 1)...)))
3133
end
3234
end
3335

3436
"""
35-
TaylorScalar{T, N}(x::S, d::S) where {S <: Number, T <: Number, N}
37+
TaylorScalar{T, N}(x::T, d::T) where {T, N}
3638
3739
Construct a Taylor polynomial with zeroth and first order coefficient, acting as a seed.
3840
"""
39-
@generated function TaylorScalar{T, N}(x::S, d::S) where {S <: Number, T <: Number, N}
41+
@generated function TaylorScalar{T, N}(x::T, d::T) where {T, N}
4042
return quote
4143
$(Expr(:meta, :inline))
4244
TaylorScalar((T(x), T(d), $(zeros(T, N - 2)...)))
4345
end
4446
end
4547

46-
@generated function TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T <: Number, N, M}
48+
@generated function TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M}
4749
N <= M ? quote
4850
$(Expr(:meta, :inline))
4951
TaylorScalar(value(t)[1:N])
@@ -67,13 +69,14 @@ adjoint(t::TaylorScalar) = t
6769
conj(t::TaylorScalar) = t
6870

6971
function promote_rule(::Type{TaylorScalar{T, N}},
70-
::Type{S}) where {T <: Number, S <: Number, N}
72+
::Type{S}) where {T, S, N}
7173
TaylorScalar{promote_type(T, S), N}
7274
end
7375

7476
# Number-like convention (I patched them after removing <: Number)
7577

76-
convert(::Type{TaylorScalar{T, N}}, x::Number) where {T, N} = TaylorScalar{T, N}(x)
78+
convert(::Type{TaylorScalar{T, N}}, x::TaylorScalar{T, N}) where {T, N} = x
79+
convert(::Type{TaylorScalar{T, N}}, x::S) where {T, S, N} = TaylorScalar{T, N}(convert(T, x))
7780
for op in (:+, :-, :*, :/)
7881
@eval @inline $op(a::TaylorScalar, b::Number) = $op(promote(a, b)...)
7982
@eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...)

0 commit comments

Comments
 (0)