@@ -18,6 +18,7 @@ function DI.prepare_pushforward_nokwarg(
18
18
step_der_var = derivative (f (x_var + t_var * dx_var, context_vars... ), t_var)
19
19
pf_var = substitute (step_der_var, Dict (t_var => zero (eltype (x))))
20
20
21
+ erase_cache_vars! (context_vars, contexts)
21
22
res = build_function (
22
23
pf_var, x_var, dx_var, context_vars... ; expression= Val (false ), cse= true
23
24
)
@@ -104,6 +105,7 @@ function DI.prepare_derivative_nokwarg(
104
105
context_vars = variablize (contexts)
105
106
der_var = derivative (f (x_var, context_vars... ), x_var)
106
107
108
+ erase_cache_vars! (context_vars, contexts)
107
109
res = build_function (der_var, x_var, context_vars... ; expression= Val (false ), cse= true )
108
110
(der_exe, der_exe!) = if res isa Tuple
109
111
res
@@ -179,6 +181,7 @@ function DI.prepare_gradient_nokwarg(
179
181
# Symbolic.gradient only accepts vectors
180
182
grad_var = gradient (f (x_var, context_vars... ), vec (x_var))
181
183
184
+ erase_cache_vars! (context_vars, contexts)
182
185
res = build_function (
183
186
grad_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
184
187
)
@@ -258,6 +261,7 @@ function DI.prepare_jacobian_nokwarg(
258
261
jacobian (f (x_var, context_vars... ), x_var)
259
262
end
260
263
264
+ erase_cache_vars! (context_vars, contexts)
261
265
res = build_function (jac_var, x_var, context_vars... ; expression= Val (false ), cse= true )
262
266
(jac_exe, jac_exe!) = res
263
267
return SymbolicsOneArgJacobianPrep (_sig, jac_exe, jac_exe!)
@@ -337,6 +341,7 @@ function DI.prepare_hessian_nokwarg(
337
341
hessian (f (x_var, context_vars... ), vec (x_var))
338
342
end
339
343
344
+ erase_cache_vars! (context_vars, contexts)
340
345
res = build_function (
341
346
hess_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
342
347
)
@@ -425,6 +430,7 @@ function DI.prepare_hvp_nokwarg(
425
430
hess_var = hessian (f (x_var, context_vars... ), vec (x_var))
426
431
hvp_vec_var = hess_var * vec (dx_var)
427
432
433
+ erase_cache_vars! (context_vars, contexts)
428
434
res = build_function (
429
435
hvp_vec_var,
430
436
vec (x_var),
@@ -519,6 +525,7 @@ function DI.prepare_second_derivative_nokwarg(
519
525
der_var = derivative (f (x_var, context_vars... ), x_var)
520
526
der2_var = derivative (der_var, x_var)
521
527
528
+ erase_cache_vars! (context_vars, contexts)
522
529
res = build_function (der2_var, x_var, context_vars... ; expression= Val (false ), cse= true )
523
530
(der2_exe, der2_exe!) = if res isa Tuple
524
531
res
0 commit comments