Commit b4c8780

Add theory
1 parent 9ef7cdb commit b4c8780

File tree

9 files changed: +201 −96 lines changed

docs/make.jl

+1
@@ -15,6 +15,7 @@ makedocs(;
         assets = String[]),
     pages = [
         "Home" => "index.md",
+        "Theory" => "theory.md",
         "API" => "api.md"
     ])

docs/src/theory.md

+106
```@meta
CurrentModule = TaylorDiff
```

# Theory

TaylorDiff.jl is an operator-overloading-based forward-mode automatic differentiation (AD) package. "Forward-mode" means that the core capability of this package is the following: given a function $f:\mathbb R^n\to\mathbb R^m$, an evaluation point $x\in\mathbb R^n$ and a direction $v\in\mathbb R^n$, we compute
$$
f(x),\quad\partial f(x)\times v,\quad\partial^2f(x)\times v\times v,\quad\cdots,\quad\partial^pf(x)\times v\times\cdots\times v
$$

i.e., the function value and the directional derivatives up to order $p$. This notation might be unfamiliar to Julia users who have experience with other AD packages, but $\partial f(x)$ is simply the Jacobian $J$, and $\partial f(x)\times v$ is simply the Jacobian-vector product (JVP). In other words, this is a natural generalization of the Jacobian-vector product to the Hessian-vector-vector product, and to even higher orders.
The main advantage of doing this instead of $p$ separate first-order Jacobian-vector products is that nesting first-order AD scales exponentially with $p$, while this method, also known as Taylor mode, scales (almost) linearly with $p$. We will see the reason for this claim later.
To achieve this, assume that $f$ is a composition $f_k\circ\cdots\circ f_2\circ f_1$, where each $f_i$ is a basic, simple function called a "primitive". We need to figure out how to propagate the derivatives through each step. In first-order AD, this is achieved by the "dual number" $x_0+x_1\varepsilon$, where $\varepsilon^2=0$, and for each primitive we add a method overload

$$
f(x_0+x_1\varepsilon)=f(x_0)+\partial f(x_0) x_1\varepsilon
$$

Similarly, in higher-order AD we need, for each primitive, a method overload for a truncated Taylor polynomial up to order $p$, where we use $t$ instead of $\varepsilon$ to denote the sensitivity. "Truncated" means $t^{p+1}=0$, similar to what we defined for dual numbers. So

$$
f(x_0+x_1t+x_2t^2+\cdots+x_pt^p)=?
$$

What expression should we put in place of the question mark? That expression is called the "pushforward rule", and we will discuss how to derive pushforward rules below.
## Arithmetic of polynomials

Before deriving pushforward rules, let's first introduce several basic properties of polynomials.

Let $x(t)$ and $y(t)$ both be truncated Taylor polynomials, i.e.

$$
\begin{aligned}
x&=x_0+x_1t+\cdots+x_pt^p\\
y&=y_0+y_1t+\cdots+y_pt^p
\end{aligned}
$$

Then it is obvious that polynomial addition and subtraction are component-wise,

$$
(x\pm y)_k=x_k\pm y_k
$$

and with a little derivation we also get the polynomial multiplication rule

$$
(x\times y)_k=\sum_{i=0}^kx_iy_{k-i}
$$

The polynomial division rule is less obvious, but if $x/y=z$, then equivalently $x=yz$, i.e.

$$
\left(\sum_{i=0}^py_it^i\right)\left(\sum_{i=0}^pz_it^i\right)=\sum_{i=0}^px_it^i
$$

If we relate the coefficients of $t^k$ on both sides, we get

$$
\sum_{i=0}^k z_iy_{k-i}=x_k
$$

or, equivalently,

$$
z_k=\frac1{y_0}\left(x_k-\sum_{i=0}^{k-1}z_iy_{k-i}\right)
$$

This is a recurrence relation: we first get $z_0=x_0/y_0$, then get $z_1$ using $z_0$, then $z_2$ using $z_0,z_1$, and so on.
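The multiplication and division rules translate directly into code. Below is an illustrative Python sketch (not TaylorDiff.jl code) in which a truncated polynomial is represented as a list of coefficients $[x_0,\ldots,x_p]$:

```python
def poly_mul(x, y):
    """Truncated product: (x*y)_k = sum_{i=0}^{k} x_i * y_{k-i}."""
    p = len(x) - 1
    return [sum(x[i] * y[k - i] for i in range(k + 1)) for k in range(p + 1)]

def poly_div(x, y):
    """Truncated quotient z = x/y via the recurrence
    z_k = (x_k - sum_{i=0}^{k-1} z_i * y_{k-i}) / y_0."""
    p = len(x) - 1
    z = []
    for k in range(p + 1):
        z.append((x[k] - sum(z[i] * y[k - i] for i in range(k))) / y[0])
    return z
```

Note how `poly_div` consumes its own earlier outputs `z[i]`, exactly the recurrence above; `poly_div(poly_mul(x, y), y)` recovers `x` up to rounding.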
## Pushforward rule for elementary functions

Let's now consider how to derive the pushforward rule for elementary functions, using $\exp$ and $\log$ as two examples.

If $x(t)$ is a polynomial and we want $e(t)=\exp(x(t))$, we can obtain it by formulating an ordinary differential equation:

$$
e'(t)=\exp(x(t))x'(t);\quad e_0=\exp(x_0)
$$

If we expand both $e$ and $x$ in this equation, we get

$$
\sum_{i=1}^pie_it^{i-1}=\left(\sum_{i=0}^{p-1} e_it^i\right)\left(\sum_{i=1}^pix_it^{i-1}\right)
$$

Relating the coefficients of $t^{k-1}$ on both sides, we get

$$
ke_k=\sum_{i=0}^{k-1}e_i\times (k-i)x_{k-i}
$$

This is, again, a recurrence relation, so we can obtain $e_1,\cdots,e_p$ step by step.
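The $\exp$ recurrence is short enough to sketch directly (again an illustrative Python version; `taylor_exp` is our name, not the package's):

```python
import math

def taylor_exp(x):
    """Taylor coefficients of exp(x(t)) from the recurrence
    k*e_k = sum_{i=0}^{k-1} e_i * (k-i) * x_{k-i}."""
    p = len(x) - 1
    e = [math.exp(x[0])]
    for k in range(1, p + 1):
        e.append(sum(e[i] * (k - i) * x[k - i] for i in range(k)) / k)
    return e
```

For the seed polynomial $x(t)=t$, i.e. coefficients `[0, 1, 0, 0]`, this reproduces the series $\exp(t)=1+t+t^2/2+t^3/6+\cdots$, coefficient by coefficient.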
If $x(t)$ is a polynomial and we want $l(t)=\log(x(t))$, we can likewise formulate an ordinary differential equation:

$$
l'(t)=\frac1{x(t)}x'(t);\quad l_0=\log(x_0)
$$

If we expand both $l$ and $x$ in this equation, the right-hand side is simply a polynomial division, and we get

$$
l_k=\frac1{x_0}\left(x_k-\frac1k\sum_{i=1}^{k-1}il_ix_{k-i}\right)
$$
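This recurrence can be sketched in the same style (illustrative Python, not TaylorDiff.jl code; `taylor_log` is our name):

```python
import math

def taylor_log(x):
    """Taylor coefficients of log(x(t)) via
    l_k = (x_k - (1/k) * sum_{i=1}^{k-1} i*l_i*x_{k-i}) / x_0."""
    p = len(x) - 1
    l = [math.log(x[0])]
    for k in range(1, p + 1):
        l.append((x[k] - sum(i * l[i] * x[k - i] for i in range(1, k)) / k) / x[0])
    return l
```

For $x(t)=1+t$, i.e. coefficients `[1, 1, 0, 0]`, this reproduces $\log(1+t)=t-t^2/2+t^3/3-\cdots$.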
---

Now notice the difference between the rules for $\exp$ and $\log$: the derivative of the exponential is the exponential itself, so we obtain a recurrence relation; the derivative of the logarithm is $1/x$, an algebraic expression in $x$, so it can be computed directly. Similarly, we have $(\tan x)'=1+\tan^2x$, but $(\arctan x)'=(1+x^2)^{-1}$. We summarize (omitting the proof) that

- the derivative of every $\exp$-like function ($\sin$, $\cos$, $\tan$, $\sinh$, ...) is recursive;
- the derivative of every $\log$-like function ($\arcsin$, $\arccos$, $\arctan$, $\operatorname{arcsinh}$, ...) is algebraic.

So all elementary functions have an easy pushforward rule that can be computed in $O(p^2)$ time. Note that this is an elegant and straightforward corollary of the definition of "elementary function" in differential algebra.
## Generic pushforward rule

For a generic $f(x)$, if we don't bother deriving its specific recurrence rule, we can still generate a pushforward rule automatically in the following manner. Denote the derivative of $f$ w.r.t. $x$ by $d(x)$; then for $f(t)=f(x(t))$ we have

$$
f'(t)=d(x(t))x'(t);\quad f(0)=f(x_0)
$$

When we expand $f$ and $x$ up to order $p$ in this equation, we notice that $d(x(t))$ is only needed up to order $p-1$. In other words, we turn the problem of finding the $p$-th order pushforward for $f$ into the problem of finding the $(p-1)$-th order pushforward for $d$, and we can recurse down to first order. The first-order derivative expressions are obtained from ChainRules.jl, which makes this process fully automatic.
This strategy is in principle equivalent to nesting first-order differentiation, which could potentially lead to exponential scaling; in practice, however, there is a huge difference. The generation of the pushforward rule happens at **compile time**, which gives the compiler a chance to eliminate redundant expressions and optimize the rule down to quadratic time. The compiler has stack limits, but this should work at least up to order 100.

In the current implementation of TaylorDiff.jl, all $\log$-like functions' pushforward rules are generated by this strategy, since their derivatives are simple algebraic expressions; some $\exp$-like functions, like $\sinh$, are also generated this way; the most frequently used $\exp$-like functions are hand-written with hand-derived recurrence relations.

If you find that the code generated by this strategy is slow, please file an issue and we will look into it.

src/derivative.jl

+19-38
@@ -1,14 +1,26 @@
+export derivative, derivative!, derivatives
 
-export derivative, derivative!, derivatives, make_seed
+# Added to help Zygote infer types
+@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
+@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(
+    make_seed, x, l, Val{P}())
 
 """
+    derivative(f, x, ::Val{P})
     derivative(f, x, l, ::Val{P})
     derivative(f!, y, x, l, ::Val{P})
 
-Computes `P`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
+Computes `P`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. If `x` is a Number, the direction `l` can be omitted.
 """
 function derivative end
 
+@inline derivative(f, x::Number, p::Val{P}) where {P} = extract_derivative(
+    derivatives(f, x, one(x), p), p)
+@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
+    derivatives(f, x, l, p), p)
+@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
+    derivatives(f!, y, x, l, p), p)
+
 """
     derivative!(result, f, x, l, ::Val{P})
     derivative!(result, f!, y, x, l, ::Val{P})
@@ -17,6 +29,11 @@ In-place derivative calculation APIs. `result` is expected to be pre-allocated a
 """
 function derivative! end
 
+@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
+    result, derivatives(f, x, l, p), p)
+@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
+    result, derivatives(f!, y, x, l, p), p)
+
 """
     derivatives(f, x, l, ::Val{P})
     derivatives(f!, y, x, l, ::Val{P})
@@ -25,43 +42,7 @@ Computes all derivatives of `f` at `x` up to order `P`.
 """
 function derivatives end
 
-# Convenience wrapper for adding unit seed to the input
-
-@inline derivative(f, x, p::Int64) = derivative(f, x, broadcast(one, x), p)
-
-# Convenience wrappers for converting ps to value types
-# and forward work to core APIs
-
-@inline derivative(f, x, l, p::Int64) = derivative(f, x, l, Val{p}())
-@inline derivative(f!, y, x, l, p::Int64) = derivative(f!, y, x, l, Val{p}())
-@inline derivative!(result, f, x, l, p::Int64) = derivative!(
-    result, f, x, l, Val{p}())
-@inline derivative!(result, f!, y, x, l, p::Int64) = derivative!(
-    result, f!, y, x, l, Val{p}())
-
-# Core APIs
-
-# Added to help Zygote infer types
-@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
-@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(
-    make_seed, x, l, Val{P}())
-
-# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
-@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
-    derivatives(f, x, l, p), p)
-@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
-    derivatives(f!, y, x, l, p), p)
-@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
-    result, derivatives(f, x, l, p), p)
-@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
-    result, derivatives(f!, y, x, l, p), p)
-
-# `derivatives` API: computes all derivatives of `f` at `x` up to p `P - 1`
-
-# Out-of-place function
 @inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p))
-
-# In-place function
 @inline function derivatives(f!, y, x, l, p::Val{P}) where {P}
     buffer = similar(y, TaylorScalar{eltype(y), P})
     f!(buffer, make_seed(x, l, p))

src/primitive.jl

+14-14
@@ -44,7 +44,7 @@ end
 
 ## Hand-written exp, sin, cos
 
-@to_static function exp(t::TaylorScalar{T, P}) where {P, T}
+@immutable function exp(t::TaylorScalar{T, P}) where {P, T}
     f = flatten(t)
     v[0] = exp(f[0])
     for i in 1:P
@@ -58,7 +58,7 @@ end
 end
 
 for func in (:sin, :cos)
-    @eval @to_static function $func(t::TaylorScalar{T, P}) where {T, P}
+    @eval @immutable function $func(t::TaylorScalar{T, P}) where {T, P}
         f = flatten(t)
         s[0], c[0] = sincos(f[0])
         for i in 1:P
@@ -104,7 +104,7 @@ end
 @inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
     value(a) - value(b), map(-, partials(a), partials(b)))
 
-@to_static function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
+@immutable function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
     va, vb = flatten(a), flatten(b)
     for i in 0:P
         v[i] = zero(T)
@@ -115,7 +115,7 @@ end
     TaylorScalar(v)
 end
 
-@to_static function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
+@immutable function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
     va, vb = flatten(a), flatten(b)
     v[0] = va[0] / vb[0]
     for i in 1:P
@@ -130,13 +130,13 @@ end
 
 @inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{0}) = one(x)
 @inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{1}) = x
-@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{2}) = x*x
-@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{3}) = x*x*x
+@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{2}) = x * x
+@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{3}) = x * x * x
 @inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-1}) = inv(x)
-@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i=inv(x); i*i)
+@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i = inv(x); i * i)
 
 for R in (Integer, Real)
-    @eval @to_static function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
+    @eval @immutable function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
         f = flatten(t)
         v[0] = f[0]^n
         for i in 1:P
@@ -153,14 +153,14 @@ end
 
 ^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))
 
-@inline function lower(t::TaylorScalar{T, P}) where {T, P}
+@inline function differentiate(t::TaylorScalar{T, P}) where {T, P}
     s = partials(t)
     TaylorScalar(ntuple(i -> s[i] * i, Val(P)))
 end
-@inline function higher(t::TaylorScalar{T, P}) where {T, P}
+@inline function integrate(t::TaylorScalar{T, P}, C::T) where {T, P}
     s = flatten(t)
-    ntuple(i -> s[i] / i, Val(P + 1))
+    TaylorScalar(C, ntuple(i -> s[i] / i, Val(P + 1)))
 end
-@inline raise(f, df::TaylorScalar, t) = TaylorScalar(f, higher(lower(t) * df))
-@inline raise(f, df::Number, t) = df * t
-@inline raiseinv(f, df, t) = TaylorScalar(f, higher(lower(t) / df))
+@inline raise(f0, d::TaylorScalar, t) = integrate(differentiate(t) * d, f0)
+@inline raise(f0, d::Number, t) = d * t
+@inline raiseinv(f0, d, t) = integrate(differentiate(t) / d, f0)

src/scalar.jl

+1
@@ -40,6 +40,7 @@ Convenience function: construct a Taylor polynomial with zeroth and first order
 TaylorScalar{P}(value::T, seed::T) where {T, P} = TaylorScalar(
     value, ntuple(i -> i == 1 ? seed : zero(T), Val(P)))
 
+# Truncate or extend the order of a Taylor polynomial.
 function TaylorScalar{P}(t::TaylorScalar{T, Q}) where {T, P, Q}
     v = value(t)
     p = partials(t)

src/utils.jl

+16-2
@@ -1,3 +1,6 @@
+# This file is a bunch of compiler magic to cleverly define pushforward rules.
+# If you are only interested in data structures and pushforward rules, you can skip this file.
+
 using ChainRules
 using ChainRulesCore
 using Symbolics: @variables, @rule, unwrap, isdiv
@@ -6,7 +9,9 @@ using MacroTools
 using MacroTools: prewalk, postwalk
 
 """
-Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule.
+Pick a strategy for raising the derivative of a function.
+If the derivative is like 1 over something, raise with the division rule;
+otherwise, raise with the multiplication rule.
 """
 function get_term_raiser(func)
     @variables z
@@ -95,7 +100,16 @@ function process(d, expr)
     end
 end
 
-macro to_static(def)
+"""
+    immutable(def)
+
+Transform a function definition into a `@generated` function:
+
+1. allocations are removed by replacing the output with scalar variables;
+2. loops are unrolled;
+3. indices are modified to use 1-based indexing.
+"""
+macro immutable(def)
     dict = splitdef(def)
     pairs = Any[]
     for symbol in dict[:whereparams]

test/derivative.jl

+7-7
@@ -1,16 +1,16 @@
 
 @testset "O-function, O-derivative" begin
     g(x) = x^3
-    @test derivative(g, 1.0, 1) ≈ 3
+    @test derivative(g, 1.0, Val(1)) ≈ 3
 
     h(x) = x .^ 3
-    @test derivative(h, [2.0 3.0], 1) ≈ [12.0 27.0]
+    @test derivative(h, [2.0 3.0], [1.0 1.0], Val(1)) ≈ [12.0 27.0]
 
     g1(x) = x[1] * x[1] + x[2] * x[2]
-    @test derivative(g1, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0
+    @test derivative(g1, [1.0, 2.0], [1.0, 0.0], Val(1)) ≈ 2.0
 
     h1(x) = sum(x, dims = 1)
-    @test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], 1) ≈ [2.0 2.0]
+    @test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], Val(1)) ≈ [2.0 2.0]
 end
 
 @testset "I-function, O-derivative" begin
@@ -20,12 +20,12 @@ end
     end
     x = 2.0
     y = [0.0, 0.0]
-    @test derivative(g!, y, x, 1.0, Val{1}()) ≈ [4.0, 1.0]
+    @test derivative(g!, y, x, 1.0, Val(1)) ≈ [4.0, 1.0]
 end
 
 @testset "O-function, I-derivative" begin
     g(x) = x .^ 2
-    @test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val{1}()) ≈ [2.0, 0.0]
+    @test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val(1)) ≈ [2.0, 0.0]
 end
 
 @testset "I-function, I-derivative" begin
@@ -35,5 +35,5 @@ end
     end
     x = [2.0, 3.0]
     y = [0.0, 0.0]
-    @test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val{1}()) ≈ [4.0, 0.0]
+    @test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val(1)) ≈ [4.0, 0.0]
end
