Skip to content

Commit 9e4b171

Browse files
committed
Fix ChainRules codegen problem
1 parent 2cf2659 commit 9e4b171

File tree

5 files changed

+79
-9
lines changed

5 files changed

+79
-9
lines changed

examples/ode.jl

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using ADTypes
2+
using DifferentiationInterface
3+
using ModelingToolkit, DifferentialEquations
4+
using TaylorDiff, ForwardDiff
5+
using Enzyme, Zygote, ReverseDiff
6+
using SciMLSensitivity
7+
8+
@parameters a
9+
@variables t x1(t)
10+
D = Differential(t)
11+
states = [x1]
12+
parameters = [a]
13+
14+
@named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
15+
model = structural_simplify(pre_model)
16+
17+
ic = Dict(x1 => 1.0)
18+
p_true = Dict(a => 2.0)
19+
20+
problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
21+
soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
22+
display(soln(0.5, idxs = [x1]))
23+
24+
function different_time(new_ic, new_params, new_t)
25+
#newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t*2], new_params)
26+
#newprob = remake(problem, u0=new_ic, tspan = [0.0, new_t], p = new_params)
27+
newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
28+
newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
29+
new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
30+
return (soln(new_t, idxs = [x1]))
31+
end
32+
33+
function just_t(new_t)
34+
return different_time(ic, p_true, new_t)[1]
35+
end
36+
display(different_time(ic, p_true, 2e-5))
37+
display(just_t(0.5))
38+
39+
#display(ForwardDiff.derivative(just_t,1.0))
40+
display(TaylorDiff.derivative(just_t, 1.0, 1)) #isnan error
41+
#display(value_and_gradient(just_t, AutoForwardDiff(), 1.0))
42+
#display(value_and_gradient(just_t, AutoReverseDiff(), 1.0))
43+
#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0))
44+
#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0))
45+
#display(value_and_gradient(just_t, AutoZygote(), 1.0))

src/codegen.jl

+11-6
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,24 @@ using Symbolics: @variables
44
using SymbolicUtils, SymbolicUtils.Code
55
using SymbolicUtils: Pow
66

7-
dummy = (NoTangent(), 1)
8-
@variables z
9-
for func in (+, -, deg2rad, rad2deg,
7+
func_list = (
8+
+, -, deg2rad, rad2deg,
109
sinh, cosh, tanh,
1110
asin, acos, atan, asec, acsc, acot,
1211
log, log10, log1p, log2,
1312
asinh, acosh, atanh, asech, acsch,
1413
acoth,
1514
abs, sign)
15+
16+
dummy = (NoTangent(), 1)
17+
@variables z
18+
for func in func_list
1619
F = typeof(func)
1720
# base case
1821
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
1922
t0, t1 = value(t)
20-
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
23+
f0, f1 = frule((NoTangent(), t1), op, t0)
24+
TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
2125
end
2226
der = frule(dummy, func, z)[2]
2327
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
@@ -28,8 +32,9 @@ for func in (+, -, deg2rad, rad2deg,
2832
quote
2933
$(Expr(:meta, :inline))
3034
z = TaylorScalar{T, N - 1}(t)
31-
df = $der_expr
32-
$$raiser($f(value(t)[1]), df, t)
35+
f0 = $f(value(t)[1])
36+
df = zero_tangent(z) + $der_expr
37+
$$raiser(f0, df, t)
3338
end
3439
end
3540
end

src/primitive.jl

+5
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,13 @@ for R in (Integer, Real)
151151
ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...)))
152152
return :(@inbounds $ex)
153153
end
154+
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
155+
exp(t * log(a))
156+
end
154157
end
155158

159+
^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))
160+
156161
@generated function raise(f::T, df::TaylorScalar{T, M},
157162
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
158163
return quote

src/scalar.jl

+17-2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ function promote_rule(::Type{TaylorScalar{T, N}},
9090
TaylorScalar{promote_type(T, S), N}
9191
end
9292

93-
function Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N}
94-
TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value))
93+
function (::Type{F})(x::TaylorScalar{T, N}) where {T, N, F <: AbstractFloat}
94+
F(primal(x))
95+
end
96+
97+
function Base.nextfloat(x::TaylorScalar{T, N}) where {T, N}
98+
TaylorScalar{T, N}(ntuple(i -> i == 1 ? nextfloat(value(x)[i]) : value(x)[i], N))
99+
end
100+
101+
function Base.prevfloat(x::TaylorScalar{T, N}) where {T, N}
102+
TaylorScalar{T, N}(ntuple(i -> i == 1 ? prevfloat(value(x)[i]) : value(x)[i], N))
103+
end
104+
105+
const UNARY_PREDICATES = Symbol[
106+
:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
107+
108+
for pred in UNARY_PREDICATES
109+
@eval Base.$(pred)(x::TaylorScalar) = $(pred)(primal(x))
95110
end

test/primitive.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using FiniteDifferences
22

33
@testset "No derivative or linear" begin
44
some_number, another_number = 1.9, 2.6
5-
for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg), order in (2,)
5+
for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg, abs, sign), order in (2,)
66
@test derivative(f, some_number, order) 0.0
77
end
88
for f in (+, -, <, <=, >, >=, ==), order in (2,)

0 commit comments

Comments
 (0)