Skip to content

Commit 6ad1d31

Browse files
format
1 parent 1b34f0f commit 6ad1d31

File tree

4 files changed

+291
-182
lines changed

4 files changed

+291
-182
lines changed

docs/make.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
using L2OALM
22
using Documenter
33

4-
DocMeta.setdocmeta!(L2OALM, :DocTestSetup, :(using L2OALM); recursive=true)
4+
DocMeta.setdocmeta!(L2OALM, :DocTestSetup, :(using L2OALM); recursive = true)
55

66
makedocs(;
7-
modules=[L2OALM],
8-
authors="Andrew <[email protected]> and contributors",
9-
sitename="L2OALM.jl",
10-
format=Documenter.HTML(;
11-
canonical="https://LearningToOptimize.github.io/L2OALM.jl",
12-
edit_link="main",
13-
assets=String[],
7+
modules = [L2OALM],
8+
authors = "Andrew <[email protected]> and contributors",
9+
sitename = "L2OALM.jl",
10+
format = Documenter.HTML(;
11+
canonical = "https://LearningToOptimize.github.io/L2OALM.jl",
12+
edit_link = "main",
13+
assets = String[],
1414
),
15-
pages=[
16-
"Home" => "index.md",
17-
],
15+
pages = ["Home" => "index.md"],
1816
)
1917

20-
deploydocs(;
21-
repo="github.com/LearningToOptimize/L2OALM.jl",
22-
devbranch="main",
23-
)
18+
deploydocs(; repo = "github.com/LearningToOptimize/L2OALM.jl", devbranch = "main")

src/L2OALM.jl

Lines changed: 126 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ using Lux.Training
99
using CUDA
1010
using Statistics
1111

12-
export LagrangianDualLoss, LagrangianPrimalLoss, TrainingStepLoop,
13-
L2OALM_epoch!, L2OALM_train!
12+
export LagrangianDualLoss,
13+
LagrangianPrimalLoss, TrainingStepLoop, L2OALM_epoch!, L2OALM_train!
1414

1515
"""
1616
LagrangianDualLoss(;max_dual=1e6)
@@ -23,27 +23,27 @@ Target dual variables are clipped from zero to `max_dual`.
2323
Keywords:
2424
- `max_dual`: Maximum value for the target dual variables.
2525
"""
26-
function LagrangianDualLoss(num_equal::Int; max_dual=1e6)
26+
function LagrangianDualLoss(num_equal::Int; max_dual = 1e6)
2727
return (dual_model, ps_dual, st_dual, data) -> begin
2828
x, hpm, dual_hat_k, gh = data
2929
ρ = hpm[:ρ]
3030
# Get current dual predictions
3131
dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual)
32-
32+
3333
# Separate bound and equality constraints
34-
gh_bound = gh[1:end-num_equal,:]
35-
gh_equal = gh[end-num_equal+1:end,:]
36-
dual_hat_bound = dual_hat_k[1:end-num_equal,:]
37-
dual_hat_equal = dual_hat_k[end-num_equal+1:end,:]
38-
34+
gh_bound = gh[1:end-num_equal, :]
35+
gh_equal = gh[end-num_equal+1:end, :]
36+
dual_hat_bound = dual_hat_k[1:end-num_equal, :]
37+
dual_hat_equal = dual_hat_k[end-num_equal+1:end, :]
38+
3939
# Target for dual variables
4040
dual_target = vcat(
4141
min.(max.(dual_hat_bound + ρ .* gh_bound, 0), max_dual),
42-
min.(dual_hat_equal + ρ .* gh_equal, max_dual)
42+
min.(dual_hat_equal + ρ .* gh_equal, max_dual),
4343
)
44-
45-
loss = mean((dual_hat .- dual_target).^2)
46-
return loss, st_dual_new, (dual_loss=loss,)
44+
45+
loss = mean((dual_hat .- dual_target) .^ 2)
46+
return loss, st_dual_new, (dual_loss = loss,)
4747
end
4848
end
4949

@@ -56,31 +56,31 @@ from current dual predictions `dual_hat` for the batch model `bm` under paramete
5656
Arguments:
5757
- `bm`: A `BatchModel` instance that contains the model and batch configuration.
5858
"""
59-
function LagrangianPrimalLoss(bm::BatchModel)
59+
function LagrangianPrimalLoss(bm::BatchModel)
6060
return (model, ps, st, data) -> begin
6161
Θ, hpm, dual_hat = data
6262
ρ = hpm[:ρ]
6363
num_s = size(Θ, 2)
6464

6565
# Forward pass for prediction
6666
X̂, st_new = model(Θ, ps, st)
67-
67+
6868
# Calculate violations and objectives
6969
objs = BNK.objective!(bm, X̂, Θ)
7070
# gh = constraints!(bm, X̂, Θ)
7171
Vc, Vb = BNK.all_violations!(bm, X̂, Θ)
7272
V = vcat(Vb, Vc)
7373
total_loss = (
74-
sum(abs.(dual_hat .* V)) / num_s +
75-
ρ / 2 * sum((V).^2) / num_s +
76-
mean(objs)
74+
sum(abs.(dual_hat .* V)) / num_s + ρ / 2 * sum((V) .^ 2) / num_s + mean(objs)
7775
)
7876

79-
return total_loss, st_new, (
80-
total_loss=total_loss,
81-
mean_violations=mean(V),
82-
new_max_violation=maximum(V),
83-
mean_objs=mean(objs),
77+
return total_loss,
78+
st_new,
79+
(
80+
total_loss = total_loss,
81+
mean_violations = mean(V),
82+
new_max_violation = maximum(V),
83+
mean_objs = mean(objs),
8484
)
8585
end
8686
end
@@ -100,7 +100,7 @@ Fields:
100100
mutable struct TrainingStepLoop
101101
loss_fn::Function
102102
stopping_criteria::Vector{Function}
103-
hyperparameters::Dict{Symbol, Any}
103+
hyperparameters::Dict{Symbol,Any}
104104
parameter_update_fns::Vector{Function}
105105
reconcile_state::Function
106106
pre_hook::Function
@@ -112,7 +112,14 @@ end
112112
Default pre-hook function for the primal model in the L2O-ALM algorithm.
113113
This function performs a forward pass through the dual model to obtain the dual predictions.
114114
"""
115-
function _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm)
115+
function _pre_hook_primal(
116+
θ,
117+
primal_model,
118+
train_state_primal,
119+
dual_model,
120+
train_state_dual,
121+
bm,
122+
)
116123
# Forward pass for dual model
117124
dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states)
118125

@@ -125,15 +132,22 @@ end
125132
Default pre-hook function for the dual model in the L2O-ALM algorithm.
126133
This function performs a forward pass through the primal model to obtain the predicted state and constraints.
127134
"""
128-
function _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm)
135+
function _pre_hook_dual(
136+
θ,
137+
primal_model,
138+
train_state_primal,
139+
dual_model,
140+
train_state_dual,
141+
bm,
142+
)
129143
# # Forward pass for primal model
130144
X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.states)
131145
gh = constraints!(bm, X̂, Θ)
132-
146+
133147
# Forward pass for dual model
134148
dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states)
135149

136-
return (dual_hat, gh,)
150+
return (dual_hat, gh)
137151
end
138152

139153
"""
@@ -149,10 +163,10 @@ function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple})
149163
mean_objs = mean([s.mean_objs for s in batch_states])
150164
mean_loss = mean([s.total_loss for s in batch_states])
151165
return (;
152-
new_max_violation=max_violation,
153-
mean_violations=mean_violations,
154-
mean_objs=mean_objs,
155-
total_loss=mean_loss,
166+
new_max_violation = max_violation,
167+
mean_violations = mean_violations,
168+
mean_objs = mean_objs,
169+
total_loss = mean_loss,
156170
)
157171
end
158172

@@ -164,7 +178,7 @@ This function computes the mean dual loss from the batch states.
164178
"""
165179
function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
166180
dual_loss = mean([s.dual_loss for s in batch_states])
167-
return (dual_loss=dual_loss,)
181+
return (dual_loss = dual_loss,)
168182
end
169183

170184
"""
@@ -173,7 +187,11 @@ end
173187
Default function to update the hyperparameter ρ in the ALM algorithm.
174188
This function increases ρ by a factor of α if the new maximum violation exceeds τ times the previous maximum violation.
175189
"""
176-
function _update_ALM_ρ!(hpm_primal::Dict{Symbol, Any}, hpm_dual::Dict{Symbol, Any}, current_state::NamedTuple)
190+
function _update_ALM_ρ!(
191+
hpm_primal::Dict{Symbol,Any},
192+
hpm_dual::Dict{Symbol,Any},
193+
current_state::NamedTuple,
194+
)
177195
if current_state.new_max_violation > hpm_primal.τ * hpm_primal.max_violation
178196
hpm_primal[:ρ] = min(hpm_primal[:ρmax], hpm_primal[:ρ] * hpm_primal[:α])
179197
hpm_dual[:ρ] = hpm_primal[:ρ] # Ensure dual model uses the same ρ
@@ -191,7 +209,7 @@ function _default_primal_loop(bm::BatchModel)
191209
return TrainingStepLoop(
192210
LagrangianPrimalLoss(bm),
193211
[(iter, current_state, hpm) -> iter >= 100 ? true : false],
194-
Dict{Symbol, Any}(
212+
Dict{Symbol,Any}(
195213
:ρ => 1.0,
196214
:ρmax => 1e6,
197215
:τ => 0.8,
@@ -200,7 +218,7 @@ function _default_primal_loop(bm::BatchModel)
200218
),
201219
[_update_ALM_ρ!],
202220
_reconcile_alm_primal_state,
203-
_pre_hook_primal
221+
_pre_hook_primal,
204222
)
205223
end
206224

@@ -213,13 +231,10 @@ function _default_dual_loop(num_equal::Int)
213231
return TrainingStepLoop(
214232
LagrangianDualLoss(num_equal),
215233
[(iter, current_state, hpm) -> iter >= 100 ? true : false],
216-
Dict{Symbol, Any}(
217-
:max_dual => 1e6,
218-
:ρ => 1.0,
219-
),
234+
Dict{Symbol,Any}(:max_dual => 1e6, :ρ => 1.0),
220235
[],
221236
_reconcile_alm_dual_state,
222-
_pre_hook_dual
237+
_pre_hook_dual,
223238
)
224239
end
225240

@@ -249,7 +264,7 @@ function L2OALM_epoch!(
249264
train_state_dual::Lux.Training.TrainState,
250265
training_step_loop_primal::TrainingStepLoop,
251266
training_step_loop_dual::TrainingStepLoop,
252-
data
267+
data,
253268
)
254269
iter_primal = 1
255270
iter_dual = 1
@@ -258,36 +273,73 @@ function L2OALM_epoch!(
258273
current_state_dual = (;)
259274

260275
# primal loop
261-
while all(stopping_criterion(iter_primal, current_state_primal, training_step_loop_primal.hyperparameters) for stopping_criterion in training_step_loop_primal.stopping_criteria)
276+
while all(
277+
stopping_criterion(
278+
iter_primal,
279+
current_state_primal,
280+
training_step_loop_primal.hyperparameters,
281+
) for stopping_criterion in training_step_loop_primal.stopping_criteria
282+
)
262283
current_states_primal = Vector{NamedTuple}(undef, num_batches)
263284
iter_batch = 1
264285
for (θ) in data
265286
_, loss_val, stats, train_state_primal = Training.single_train_step!(
266287
AutoZygote(), # AD backend
267288
training_step_loop_primal.loss_fn, # Loss function
268-
(θ, training_step_loop_primal.hyperparameters, training_step_loop_primal.pre_hook(θ, primal_model, train_state_primal, dual_model, train_state_dual)...), # Data
269-
train_state_primal # Training state
289+
(
290+
θ,
291+
training_step_loop_primal.hyperparameters,
292+
training_step_loop_primal.pre_hook(
293+
θ,
294+
primal_model,
295+
train_state_primal,
296+
dual_model,
297+
train_state_dual,
298+
)...,
299+
), # Data
300+
train_state_primal, # Training state
270301
)
271302
current_states_primal[iter_batch] = stats
272303
iter_batch += 1
273304
end
274-
current_state_primal = training_step_loop_primal.reconcile_state(current_states_primal)
305+
current_state_primal =
306+
training_step_loop_primal.reconcile_state(current_states_primal)
275307
iter_primal += 1
276308
end
277309
for fn in training_step_loop_primal.parameter_update_fns
278-
fn(training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters, current_state_primal)
310+
fn(
311+
training_step_loop_primal.hyperparameters,
312+
training_step_loop_dual.hyperparameters,
313+
current_state_primal,
314+
)
279315
end
280316

281317
# dual loop
282-
while all(stopping_criterion(iter_dual, current_state_dual, training_step_loop_dual.hyperparameters) for stopping_criterion in training_step_loop_dual.stopping_criteria)
318+
while all(
319+
stopping_criterion(
320+
iter_dual,
321+
current_state_dual,
322+
training_step_loop_dual.hyperparameters,
323+
) for stopping_criterion in training_step_loop_dual.stopping_criteria
324+
)
283325
current_states_dual = Vector{NamedTuple}(undef, num_batches)
284326
iter_batch = 1
285327
for (θ) in data
286328
_, loss_val, stats, train_state_dual = Training.single_train_step!(
287329
AutoZygote(), # AD backend
288330
training_step_loop_dual.loss_fn, # Loss function
289-
(θ, training_step_loop_dual.hyperparameters, training_step_loop_dual.pre_hook(θ, primal_model, train_state_primal, dual_model, train_state_dual)...), # Data
290-
train_state_dual # Training state
331+
(
332+
θ,
333+
training_step_loop_dual.hyperparameters,
334+
training_step_loop_dual.pre_hook(
335+
θ,
336+
primal_model,
337+
train_state_primal,
338+
dual_model,
339+
train_state_dual,
340+
)...,
341+
), # Data
342+
train_state_dual, # Training state
291343
)
292344
current_states_dual[iter_batch] = stats
293345
iter_batch += 1
@@ -296,7 +348,11 @@ function L2OALM_epoch!(
296348
iter_dual += 1
297349
end
298350
for fn in training_step_loop_dual.parameter_update_fns
299-
fn(training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters, current_state_dual)
351+
fn(
352+
training_step_loop_primal.hyperparameters,
353+
training_step_loop_dual.hyperparameters,
354+
current_state_dual,
355+
)
300356
end
301357
return
302358
end
@@ -334,20 +390,33 @@ function L2OALM_train!(
334390
train_state_primal::Lux.Training.TrainState,
335391
train_state_dual::Lux.Training.TrainState,
336392
data;
337-
training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm),
338-
training_step_loop_dual::TrainingStepLoop=_default_dual_loop(num_equal),
339-
stopping_criteria::Vector{F}=[(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> iter >= 100 ? true : false],
340-
) where F<:Function
393+
training_step_loop_primal::TrainingStepLoop = _default_primal_loop(bm),
394+
training_step_loop_dual::TrainingStepLoop = _default_dual_loop(num_equal),
395+
stopping_criteria::Vector{F} = [
396+
(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) ->
397+
iter >= 100 ? true : false,
398+
],
399+
) where {F<:Function}
341400
iter = 1
342-
while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual, training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters) for stopping_criterion in stopping_criteria)
401+
while all(
402+
stopping_criterion(
403+
iter,
404+
primal_model,
405+
dual_model,
406+
train_state_primal,
407+
train_state_dual,
408+
training_step_loop_primal.hyperparameters,
409+
training_step_loop_dual.hyperparameters,
410+
) for stopping_criterion in stopping_criteria
411+
)
343412
L2OALM_epoch!(
344413
primal_model,
345414
train_state_primal,
346415
dual_model,
347416
train_state_dual,
348417
training_step_loop_primal,
349418
training_step_loop_dual,
350-
data
419+
data,
351420
)
352421
iter += 1
353422
end

0 commit comments

Comments
 (0)