Skip to content

Commit 8f8e5d9

Browse files
committed
Refactor SpecialFunctionsExt
1 parent c90fc69 commit 8f8e5d9

File tree

4 files changed

+36
-59
lines changed

4 files changed

+36
-59
lines changed

ext/TaylorDiffSFExt.jl

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,8 @@
11
module TaylorDiffSFExt
22
using TaylorDiff, SpecialFunctions
3-
using Symbolics: @variables
4-
using SymbolicUtils, SymbolicUtils.Code
5-
using SymbolicUtils: Pow
6-
using TaylorDiff: value, raise
7-
using ChainRules, ChainRulesCore
83

9-
dummy = (NoTangent(), 1)
10-
@variables z
11-
# logerfc, logerfcx, erfinv, gamma, digamma, trigamma
124
for func in (erf, erfc, erfcinv, erfcx, erfi)
13-
F = typeof(func)
14-
# base case
15-
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
16-
t0, t1 = value(t)
17-
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
18-
end
19-
der = frule(dummy, func, z)[2]
20-
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
21-
# recursion by raising
22-
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
23-
der_expr = $(QuoteNode(toexpr(term)))
24-
f = $func
25-
quote
26-
$(Expr(:meta, :inline))
27-
z = TaylorScalar{T, N - 1}(t)
28-
df = $der_expr
29-
$$raiser($f(value(t)[1]), df, t)
30-
end
31-
end
5+
TaylorDiff.define_unary_function(func, TaylorDiffSFExt)
326
end
337

348
end

src/TaylorDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module TaylorDiff
22

33
include("scalar.jl")
44
include("primitive.jl")
5+
include("utils.jl")
56
include("codegen.jl")
67
include("derivative.jl")
78
include("chainrules.jl")

src/codegen.jl

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,10 @@
1-
using ChainRules
2-
using ChainRulesCore
3-
using Symbolics: @variables
4-
using SymbolicUtils, SymbolicUtils.Code
5-
using SymbolicUtils: Pow
6-
7-
func_list = (
1+
for unary_func in (
82
+, -, deg2rad, rad2deg,
93
sinh, cosh, tanh,
104
asin, acos, atan, asec, acsc, acot,
115
log, log10, log1p, log2,
126
asinh, acosh, atanh, asech, acsch,
137
acoth,
148
abs, sign)
15-
16-
dummy = (NoTangent(), 1)
17-
@variables z
18-
for func in func_list
19-
F = typeof(func)
20-
# base case
21-
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
22-
t0, t1 = value(t)
23-
f0, f1 = frule((NoTangent(), t1), op, t0)
24-
TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
25-
end
26-
der = frule(dummy, func, z)[2]
27-
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
28-
# recursion by raising
29-
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
30-
der_expr = $(QuoteNode(toexpr(term)))
31-
f = $func
32-
quote
33-
$(Expr(:meta, :inline))
34-
z = TaylorScalar{T, N - 1}(t)
35-
f0 = $f(value(t)[1])
36-
df = zero_tangent(z) + $der_expr
37-
$$raiser(f0, df, t)
38-
end
39-
end
9+
define_unary_function(unary_func, TaylorDiff)
4010
end

src/utils.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using ChainRules
2+
using ChainRulesCore
3+
using Symbolics: @variables
4+
using SymbolicUtils, SymbolicUtils.Code
5+
using SymbolicUtils: Pow
6+
7+
dummy = (NoTangent(), 1)
8+
@variables z
9+
10+
function define_unary_function(func, m)
11+
F = typeof(func)
12+
# base case
13+
@eval m function (op::$F)(t::TaylorScalar{T, 2}) where {T}
14+
t0, t1 = value(t)
15+
f0, f1 = frule((NoTangent(), t1), op, t0)
16+
TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
17+
end
18+
der = frule(dummy, func, z)[2]
19+
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
20+
# recursion by raising
21+
@eval m @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
22+
der_expr = $(QuoteNode(toexpr(term)))
23+
f = $func
24+
quote
25+
$(Expr(:meta, :inline))
26+
z = TaylorScalar{T, N - 1}(t)
27+
f0 = $f(value(t)[1])
28+
df = zero_tangent(z) + $der_expr
29+
$$raiser(f0, df, t)
30+
end
31+
end
32+
end

0 commit comments

Comments
 (0)