Skip to content

Commit 51c56e8

Browse files
authored
fix: erase cache contexts before building function in Symbolics (#760)
1 parent 9546d6d commit 51c56e8

File tree

6 files changed

+25
-4
lines changed

6 files changed

+25
-4
lines changed

DifferentiationInterface/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.49"
4+
version = "0.6.50"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/explanation/backends.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Moreover, each context type is supported by a specific subset of backends:
7070
| `AutoMooncake` |||
7171
| `AutoPolyesterForwardDiff` |||
7272
| `AutoReverseDiff` |||
73-
| `AutoSymbolics` || |
73+
| `AutoSymbolics` || |
7474
| `AutoTracker` |||
7575
| `AutoZygote` || 🔀 |
7676

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,23 @@ variablize(::Number, name::Symbol) = variable(name)
2828
variablize(x::AbstractArray, name::Symbol) = variables(name, axes(x)...)
2929

3030
function variablize(contexts::NTuple{C,DI.Context}) where {C}
31-
map(enumerate(contexts)) do (k, c)
31+
return ntuple(Val(C)) do k
32+
c = contexts[k]
3233
variablize(DI.unwrap(c), Symbol("context$k"))
3334
end
3435
end
3536

37+
function erase_cache_vars!(
38+
context_vars::NTuple{C}, contexts::NTuple{C,DI.Context}
39+
) where {C}
40+
# erase the active data from caches before building function
41+
for (v, c) in zip(context_vars, contexts)
42+
if c isa DI.Cache
43+
fill!(v, zero(eltype(v)))
44+
end
45+
end
46+
end
47+
3648
include("onearg.jl")
3749
include("twoarg.jl")
3850

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

+7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ function DI.prepare_pushforward_nokwarg(
1818
step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var)
1919
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
2020

21+
erase_cache_vars!(context_vars, contexts)
2122
res = build_function(
2223
pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true
2324
)
@@ -104,6 +105,7 @@ function DI.prepare_derivative_nokwarg(
104105
context_vars = variablize(contexts)
105106
der_var = derivative(f(x_var, context_vars...), x_var)
106107

108+
erase_cache_vars!(context_vars, contexts)
107109
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)
108110
(der_exe, der_exe!) = if res isa Tuple
109111
res
@@ -179,6 +181,7 @@ function DI.prepare_gradient_nokwarg(
179181
# Symbolic.gradient only accepts vectors
180182
grad_var = gradient(f(x_var, context_vars...), vec(x_var))
181183

184+
erase_cache_vars!(context_vars, contexts)
182185
res = build_function(
183186
grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true
184187
)
@@ -258,6 +261,7 @@ function DI.prepare_jacobian_nokwarg(
258261
jacobian(f(x_var, context_vars...), x_var)
259262
end
260263

264+
erase_cache_vars!(context_vars, contexts)
261265
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)
262266
(jac_exe, jac_exe!) = res
263267
return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!)
@@ -337,6 +341,7 @@ function DI.prepare_hessian_nokwarg(
337341
hessian(f(x_var, context_vars...), vec(x_var))
338342
end
339343

344+
erase_cache_vars!(context_vars, contexts)
340345
res = build_function(
341346
hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true
342347
)
@@ -425,6 +430,7 @@ function DI.prepare_hvp_nokwarg(
425430
hess_var = hessian(f(x_var, context_vars...), vec(x_var))
426431
hvp_vec_var = hess_var * vec(dx_var)
427432

433+
erase_cache_vars!(context_vars, contexts)
428434
res = build_function(
429435
hvp_vec_var,
430436
vec(x_var),
@@ -519,6 +525,7 @@ function DI.prepare_second_derivative_nokwarg(
519525
der_var = derivative(f(x_var, context_vars...), x_var)
520526
der2_var = derivative(der_var, x_var)
521527

528+
erase_cache_vars!(context_vars, contexts)
522529
res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true)
523530
(der2_exe, der2_exe!) = if res isa Tuple
524531
res

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ function DI.prepare_pushforward_nokwarg(
2626
step_der_var = derivative(y_var, t_var)
2727
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
2828

29+
erase_cache_vars!(context_vars, contexts)
2930
res = build_function(
3031
pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true
3132
)
@@ -116,6 +117,7 @@ function DI.prepare_derivative_nokwarg(
116117
f!(y_var, x_var, context_vars...)
117118
der_var = derivative(y_var, x_var)
118119

120+
erase_cache_vars!(context_vars, contexts)
119121
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)
120122
(der_exe, der_exe!) = res
121123
return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!)
@@ -203,6 +205,7 @@ function DI.prepare_jacobian_nokwarg(
203205
jacobian(y_var, x_var)
204206
end
205207

208+
erase_cache_vars!(context_vars, contexts)
206209
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)
207210
(jac_exe, jac_exe!) = res
208211
return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)

DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ test_differentiation(
2222
test_differentiation(
2323
AutoSymbolics(),
2424
default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false);
25-
excluded=[:jacobian], # TODO: figure out why this fails
2625
logging=LOGGING,
2726
);
2827

0 commit comments

Comments
 (0)