Skip to content

Commit 18934af

Browse files
committed
Use macro for shared caches
1 parent dccc1dd commit 18934af

File tree

8 files changed

+47
-59
lines changed

8 files changed

+47
-59
lines changed

src/algorithms/multistep.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
22
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
33
vjp_autodiff = nothing, linesearch = NoLineSearch())
4-
scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
4+
forward_ad = ifelse(autodiff isa ADTypes.AbstractForwardMode, autodiff, nothing)
5+
scheme_concrete = apply_patch(
6+
scheme, (; autodiff, vjp_autodiff, jvp_autodiff = forward_ad))
57
descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
68
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),
7-
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff)
9+
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff, forward_ad)
810
end

src/descent/damped_newton.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,7 @@ function __internal_init(
5858
shared::Val{N} = Val(1), kwargs...) where {INV, N}
5959
length(fu) != length(u) &&
6060
@assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense."
61-
@bb δu = similar(u)
62-
δus = N 1 ? nothing : map(2:N) do i
63-
@bb δu_ = similar(u)
64-
end
65-
61+
δu, δus = @shared_caches N (@bb δu = similar(u))
6662
normal_form_damping = returns_norm_form_damping(alg.damping_fn)
6763
normal_form_linsolve = __needs_square_A(alg.linsolve, u)
6864
if u isa Number

src/descent/dogleg.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u;
5656
linsolve_kwargs, abstol, reltol, shared, kwargs...)
5757
cauchy_cache = __internal_init(prob, alg.steepest_descent, J, fu, u; pre_inverted,
5858
linsolve_kwargs, abstol, reltol, shared, kwargs...)
59-
@bb δu = similar(u)
60-
δus = N 1 ? nothing : map(2:N) do i
61-
@bb δu_ = similar(u)
62-
end
59+
δu, δus = @shared_caches N (@bb δu = similar(u))
6360
@bb δu_cache_1 = similar(u)
6461
@bb δu_cache_2 = similar(u)
6562
@bb δu_cache_mul = similar(u)

src/descent/geodesic_acceleration.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::GeodesicAccelerati
8989
abstol = nothing, reltol = nothing, internalnorm::F = DEFAULT_NORM,
9090
kwargs...) where {INV, N, F}
9191
T = promote_type(eltype(u), eltype(fu))
92-
@bb δu = similar(u)
93-
δus = N 1 ? nothing : map(2:N) do i
94-
@bb δu_ = similar(u)
95-
end
92+
δu, δus = @shared_caches N (@bb δu = similar(u))
9693
descent_cache = __internal_init(prob, alg.descent, J, fu, u; shared = Val(N * 2),
9794
pre_inverted, linsolve_kwargs, abstol, reltol, kwargs...)
9895
@bb Jv = similar(fu)

src/descent/multistep.jl

+24-32
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,36 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme)
1515
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")
1616
end
1717

18-
alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T())
18+
newton_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = newton_steps(T())
1919

2020
struct __PotraPtak3 <: AbstractMultiStepScheme end
2121
const PotraPtak3 = __PotraPtak3()
2222

23-
alg_steps(::__PotraPtak3) = 2
23+
newton_steps(::__PotraPtak3) = 2
2424
nintermediates(::__PotraPtak3) = 1
2525

2626
@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
2727
jvp_autodiff = nothing
2828
end
2929
const SinghSharma4 = __SinghSharma4()
3030

31-
alg_steps(::__SinghSharma4) = 3
31+
newton_steps(::__SinghSharma4) = 4
32+
nintermediates(::__SinghSharma4) = 2
3233

3334
@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
3435
jvp_autodiff = nothing
3536
end
3637
const SinghSharma5 = __SinghSharma5()
3738

38-
alg_steps(::__SinghSharma5) = 3
39+
newton_steps(::__SinghSharma5) = 4
40+
nintermediates(::__SinghSharma5) = 2
3941

4042
@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
4143
jvp_autodiff = nothing
4244
end
4345
const SinghSharma7 = __SinghSharma7()
4446

45-
alg_steps(::__SinghSharma7) = 4
47+
newton_steps(::__SinghSharma7) = 6
4648

4749
@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme}
4850
res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end])
@@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false
7577
fus
7678
internal_cache
7779
internal_caches
80+
extra
81+
extras
7882
scheme::S
7983
timer
8084
nf::Int
@@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca
9195
end
9296

9397
function __internal_multistep_caches(
94-
scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent,
95-
prob, args...; shared::Val{N} = Val(1), kwargs...) where {N}
98+
scheme::Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5},
99+
alg::GenericMultiStepDescent, prob, args...;
100+
shared::Val{N} = Val(1), kwargs...) where {N}
96101
internal_descent = NewtonDescent(; alg.linsolve, alg.precs)
97-
internal_cache = __internal_init(
102+
return @shared_caches N __internal_init(
98103
prob, internal_descent, args...; kwargs..., shared = Val(2))
99-
internal_caches = N 1 ? nothing :
100-
map(2:N) do i
101-
__internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2))
102-
end
103-
return internal_cache, internal_caches
104104
end
105105

106+
__extras_cache(::MSS.AbstractMultiStepScheme, args...; kwargs...) = nothing, nothing
107+
106108
function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
107109
alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1),
108110
pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
109111
abstol = nothing, reltol = nothing, timer = get_timer_output(),
110112
kwargs...) where {INV, N}
111-
@bb δu = similar(u)
112-
δus = N 1 ? nothing : map(2:N) do i
113-
@bb δu_ = similar(u)
114-
end
115-
fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
113+
δu, δus = @shared_caches N (@bb δu = similar(u))
114+
fu_cache, fus_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
116115
@bb xx = similar(fu)
117-
end
118-
fus_cache = N 1 ? nothing : map(2:N) do i
119-
ntuple(MSS.nintermediates(alg.scheme)) do j
120-
@bb xx = similar(fu)
121-
end
122-
end
123-
u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
116+
end)
117+
u_cache, us_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
124118
@bb xx = similar(u)
125-
end
126-
us_cache = N 1 ? nothing : map(2:N) do i
127-
ntuple(MSS.nintermediates(alg.scheme)) do j
128-
@bb xx = similar(u)
129-
end
130-
end
119+
end)
131120
internal_cache, internal_caches = __internal_multistep_caches(
132121
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133122
abstol, reltol, timer, kwargs...)
123+
extra, extras = __extras_cache(
124+
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
125+
abstol, reltol, timer, kwargs...)
134126
return GenericMultiStepDescentCache(
135127
prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136-
internal_cache, internal_caches, alg.scheme, timer, 0)
128+
internal_cache, internal_caches, extra, extras, alg.scheme, timer, 0)
137129
end
138130

139131
function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,

src/descent/newton.jl

+2-8
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u;
3636
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
3737
abstol = nothing, reltol = nothing, timer = get_timer_output(),
3838
kwargs...) where {INV, N}
39-
@bb δu = similar(u)
40-
δus = N 1 ? nothing : map(2:N) do i
41-
@bb δu_ = similar(u)
42-
end
39+
δu, δus = @shared_caches N (@bb δu = similar(u))
4340
INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer)
4441
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
4542
linsolve_kwargs...)
@@ -64,10 +61,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent,
6461
end
6562
lincache = LinearSolverCache(alg, alg.linsolve, A, b, _vec(u); abstol, reltol,
6663
linsolve_kwargs...)
67-
@bb δu = similar(u)
68-
δus = N 1 ? nothing : map(2:N) do i
69-
@bb δu_ = similar(u)
70-
end
64+
δu, δus = @shared_caches N (@bb δu = similar(u))
7165
return NewtonDescentCache{false, normal_form}(δu, δus, lincache, JᵀJ, Jᵀfu, timer)
7266
end
7367

src/descent/steepest.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ end
3434
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
3535
timer = get_timer_output(), kwargs...) where {INV, N}
3636
INV && @assert length(fu)==length(u) "Non-Square Jacobian Inverse doesn't make sense."
37-
@bb δu = similar(u)
38-
δus = N 1 ? nothing : map(2:N) do i
39-
@bb δu_ = similar(u)
40-
end
37+
δu, δus = @shared_caches N (@bb δu = similar(u))
4138
if INV
4239
lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u);
4340
abstol, reltol, linsolve_kwargs...)

src/utils.jl

+13
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,16 @@ present in the scheme, they are ignored.
177177
push!(exprs, :(return scheme))
178178
return Expr(:block, exprs...)
179179
end
180+
181+
macro shared_caches(N, expr)
182+
@gensym cache caches
183+
return esc(quote
184+
begin
185+
$(cache) = $(expr)
186+
$(caches) = $(N) 1 ? nothing : map(2:($(N))) do i
187+
$(expr)
188+
end
189+
($cache, $caches)
190+
end
191+
end)
192+
end

0 commit comments

Comments
 (0)