@@ -9,8 +9,8 @@ using Lux.Training
99using  CUDA
1010using  Statistics
1111
12- export  LagrangianDualLoss, LagrangianPrimalLoss, TrainingStepLoop, 
13-         L2OALM_epoch!, L2OALM_train!
12+ export  LagrangianDualLoss,
13+     LagrangianPrimalLoss, TrainingStepLoop,  L2OALM_epoch!, L2OALM_train!
1414
1515""" 
1616    LagrangianDualLoss(num_equal::Int; max_dual=1e6)
@@ -23,27 +23,27 @@ Target dual variables are clipped from zero to `max_dual`.
2323Keywords: 
2424    - `max_dual`: Maximum value for the target dual variables. 
2525""" 
function LagrangianDualLoss(num_equal::Int; max_dual=1e6)
    # Split a matrix row-wise: everything except the last `num_equal` rows is
    # treated as bound-constraint data, the last `num_equal` rows as
    # equality-constraint data.
    split_rows(A) = (A[1:(end - num_equal), :], A[(end - num_equal + 1):end, :])

    # Loss closure in the `(model, ps, st, data) -> (loss, state, stats)` form.
    # `data` unpacks as `(x, hpm, dual_hat_k, gh)`; `hpm` must provide `:ρ`.
    function dual_loss_fn(dual_model, ps_dual, st_dual, data)
        x, hpm, dual_hat_k, gh = data
        ρ = hpm[:ρ]

        # Current dual predictions (the only term that depends on `ps_dual`)
        dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual)

        # Separate the bound rows from the equality rows
        gh_bound, gh_equal = split_rows(gh)
        prev_bound, prev_equal = split_rows(dual_hat_k)

        # Targets are built from the supplied `dual_hat_k`, not from `dual_hat`.
        # Bound targets are clipped to [0, max_dual]; equality targets are only
        # capped from above by `max_dual`.
        bound_target = min.(max.(prev_bound + ρ .* gh_bound, 0), max_dual)
        equal_target = min.(prev_equal + ρ .* gh_equal, max_dual)
        dual_target = vcat(bound_target, equal_target)

        # Mean squared error between predictions and targets
        loss = mean((dual_hat .- dual_target) .^ 2)
        return loss, st_dual_new, (; dual_loss=loss)
    end

    return dual_loss_fn
end
4949
@@ -56,31 +56,31 @@ from current dual predictions `dual_hat` for the batch model `bm` under paramete
5656Arguments: 
5757    - `bm`: A `BatchModel` instance that contains the model and batch configuration. 
5858""" 
function LagrangianPrimalLoss(bm::BatchModel)
    # Loss closure in the `(model, ps, st, data) -> (loss, state, stats)` form.
    # `data` unpacks as `(Θ, hpm, dual_hat)`; `hpm` must provide `:ρ`.
    function primal_loss_fn(model, ps, st, data)
        Θ, hpm, dual_hat = data
        ρ = hpm[:ρ]
        num_s = size(Θ, 2)  # second dimension of Θ; the sums below are divided by it

        # Forward pass for prediction
        X̂, st_new = model(Θ, ps, st)

        # Objectives first, then violations (same call order as before, since
        # both helpers are `!` functions and may mutate `bm`)
        objs = BNK.objective!(bm, X̂, Θ)
        Vc, Vb = BNK.all_violations!(bm, X̂, Θ)
        # `Vb` is stacked above `Vc` — presumably bound rows first; confirm
        # against the row layout of `dual_hat`.
        V = vcat(Vb, Vc)

        # Three additive terms of the loss
        multiplier_term = sum(abs.(dual_hat .* V)) / num_s
        penalty_term = ρ / 2 * sum(V .^ 2) / num_s
        objective_term = mean(objs)
        total_loss = multiplier_term + penalty_term + objective_term

        stats = (;
            total_loss=total_loss,
            mean_violations=mean(V),
            new_max_violation=maximum(V),
            mean_objs=objective_term,
        )
        return total_loss, st_new, stats
    end

    return primal_loss_fn
end
@@ -100,7 +100,7 @@ Fields:
100100mutable struct  TrainingStepLoop
101101    loss_fn:: Function 
102102    stopping_criteria:: Vector{Function} 
103-     hyperparameters:: Dict{Symbol,  Any} 
103+     hyperparameters:: Dict{Symbol,Any} 
104104    parameter_update_fns:: Vector{Function} 
105105    reconcile_state:: Function 
106106    pre_hook:: Function 
112112Default pre-hook function for the primal model in the L2O-ALM algorithm. 
113113This function performs a forward pass through the dual model to obtain the dual predictions. 
114114""" 
115- function  _pre_hook_primal (θ, primal_model, train_state_primal, dual_model, train_state_dual, bm)
115+ function  _pre_hook_primal (
116+     θ,
117+     primal_model,
118+     train_state_primal,
119+     dual_model,
120+     train_state_dual,
121+     bm,
122+ )
116123    #  Forward pass for dual model
117124    dual_hat_k, _ =  dual_model (θ, train_state_dual. parameters, train_state_dual. states)
118125
@@ -125,15 +132,22 @@ end
125132Default pre-hook function for the dual model in the L2O-ALM algorithm. 
126133This function performs a forward pass through the primal model to obtain the predicted state and constraints. 
127134""" 
function _pre_hook_dual(
    θ,
    primal_model,
    train_state_primal,
    dual_model,
    train_state_dual,
    bm,
)
    # Returns `(dual_hat, gh)`: the dual model's predictions for `θ` and the
    # constraint values of the primal prediction for `θ`.
    # NOTE(review): the visible call sites in `L2OALM_epoch!` invoke the
    # pre-hook with five arguments (no `bm`) — confirm the call matches this
    # six-argument signature.

    # Forward pass for the primal model; the returned layer state is discarded
    X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.states)

    # Evaluate the constraints at the primal prediction. This must use the
    # lowercase argument `θ`; the previous code passed `Θ` (capital theta),
    # which is not bound anywhere in this function.
    gh = constraints!(bm, X̂, θ)

    # Forward pass for the dual model; the returned layer state is discarded
    dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states)

    return (dual_hat, gh)
end
138152
139153""" 
@@ -149,10 +163,10 @@ function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple})
149163    mean_objs =  mean ([s. mean_objs for  s in  batch_states])
150164    mean_loss =  mean ([s. total_loss for  s in  batch_states])
151165    return  (;
152-         new_max_violation= max_violation,
153-         mean_violations= mean_violations,
154-         mean_objs= mean_objs,
155-         total_loss= mean_loss,
166+         new_max_violation  =   max_violation,
167+         mean_violations  =   mean_violations,
168+         mean_objs  =   mean_objs,
169+         total_loss  =   mean_loss,
156170    )
157171end 
158172
@@ -164,7 +178,7 @@ This function computes the mean dual loss from the batch states.
164178""" 
function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
    # Collect the `dual_loss` entry of every per-batch stats tuple ...
    batch_losses = map(state -> state.dual_loss, batch_states)
    # ... and collapse them into a single averaged value.
    averaged_loss = mean(batch_losses)
    return (; dual_loss=averaged_loss)
end
169183
170184""" 
173187Default function to update the hyperparameter ρ in the ALM algorithm. 
174188This function increases ρ by a factor of α if the new maximum violation exceeds τ times the previous maximum violation. 
175189""" 
176- function  _update_ALM_ρ! (hpm_primal:: Dict{Symbol, Any} , hpm_dual:: Dict{Symbol, Any} , current_state:: NamedTuple )
190+ function  _update_ALM_ρ! (
191+     hpm_primal:: Dict{Symbol,Any} ,
192+     hpm_dual:: Dict{Symbol,Any} ,
193+     current_state:: NamedTuple ,
194+ )
177195    if  current_state. new_max_violation >  hpm_primal. τ *  hpm_primal. max_violation
178196        hpm_primal[:ρ ] =  min (hpm_primal[:ρmax ], hpm_primal[:ρ ] *  hpm_primal[:α ])
179197        hpm_dual[:ρ ] =  hpm_primal[:ρ ]  #  Ensure dual model uses the same ρ
@@ -191,7 +209,7 @@ function _default_primal_loop(bm::BatchModel)
191209    return  TrainingStepLoop (
192210        LagrangianPrimalLoss (bm),
193211        [(iter, current_state, hpm) ->  iter >=  100  ?  true  :  false ],
194-         Dict {Symbol,  Any} (
212+         Dict {Symbol,Any} (
195213            :ρ  =>  1.0 ,
196214            :ρmax  =>  1e6 ,
197215            :τ  =>  0.8 ,
@@ -200,7 +218,7 @@ function _default_primal_loop(bm::BatchModel)
200218        ),
201219        [_update_ALM_ρ!],
202220        _reconcile_alm_primal_state,
203-         _pre_hook_primal
221+         _pre_hook_primal, 
204222    )
205223end 
206224
@@ -213,13 +231,10 @@ function _default_dual_loop(num_equal::Int)
213231    return  TrainingStepLoop (
214232        LagrangianDualLoss (num_equal),
215233        [(iter, current_state, hpm) ->  iter >=  100  ?  true  :  false ],
216-         Dict {Symbol, Any} (
217-             :max_dual  =>  1e6 ,
218-             :ρ  =>  1.0 ,
219-         ),
234+         Dict {Symbol,Any} (:max_dual  =>  1e6 , :ρ  =>  1.0 ),
220235        [],
221236        _reconcile_alm_dual_state,
222-         _pre_hook_dual
237+         _pre_hook_dual, 
223238    )
224239end 
225240
@@ -249,7 +264,7 @@ function L2OALM_epoch!(
249264    train_state_dual:: Lux.Training.TrainState ,
250265    training_step_loop_primal:: TrainingStepLoop ,
251266    training_step_loop_dual:: TrainingStepLoop ,
252-     data
267+     data, 
253268)
254269    iter_primal =  1 
255270    iter_dual =  1 
@@ -258,36 +273,73 @@ function L2OALM_epoch!(
258273    current_state_dual =  (;)
259274
260275    #  primal loop
261-     while  all (stopping_criterion (iter_primal, current_state_primal, training_step_loop_primal. hyperparameters) for  stopping_criterion in  training_step_loop_primal. stopping_criteria)
276+     while  all (
277+         stopping_criterion (
278+             iter_primal,
279+             current_state_primal,
280+             training_step_loop_primal. hyperparameters,
281+         ) for  stopping_criterion in  training_step_loop_primal. stopping_criteria
282+     )
262283        current_states_primal =  Vector {NamedTuple} (undef, num_batches)
263284        iter_batch =  1 
264285        for  (θ) in  data
265286            _, loss_val, stats, train_state_primal =  Training. single_train_step! (
266287                AutoZygote (),          #  AD backend
267288                training_step_loop_primal. loss_fn,  #  Loss function
268-                 (θ, training_step_loop_primal. hyperparameters, training_step_loop_primal. pre_hook (θ, primal_model, train_state_primal, dual_model, train_state_dual)... ), #  Data
269-                 train_state_primal #  Training state
289+                 (
290+                     θ,
291+                     training_step_loop_primal. hyperparameters,
292+                     training_step_loop_primal. pre_hook (
293+                         θ,
294+                         primal_model,
295+                         train_state_primal,
296+                         dual_model,
297+                         train_state_dual,
298+                     )... ,
299+                 ), #  Data
300+                 train_state_primal, #  Training state
270301            )
271302            current_states_primal[iter_batch] =  stats
272303            iter_batch +=  1 
273304        end 
274-         current_state_primal =  training_step_loop_primal. reconcile_state (current_states_primal)
305+         current_state_primal = 
306+             training_step_loop_primal. reconcile_state (current_states_primal)
275307        iter_primal +=  1 
276308    end 
277309    for  fn in  training_step_loop_primal. parameter_update_fns
278-         fn (training_step_loop_primal. hyperparameters, training_step_loop_dual. hyperparameters, current_state_primal)
310+         fn (
311+             training_step_loop_primal. hyperparameters,
312+             training_step_loop_dual. hyperparameters,
313+             current_state_primal,
314+         )
279315    end 
280316
281317    #  dual loop
282-     while  all (stopping_criterion (iter_dual, current_state_dual, training_step_loop_dual. hyperparameters) for  stopping_criterion in  training_step_loop_dual. stopping_criteria)
318+     while  all (
319+         stopping_criterion (
320+             iter_dual,
321+             current_state_dual,
322+             training_step_loop_dual. hyperparameters,
323+         ) for  stopping_criterion in  training_step_loop_dual. stopping_criteria
324+     )
283325        current_states_dual =  Vector {NamedTuple} (undef, num_batches)
284326        iter_batch =  1 
285327        for  (θ) in  data
286328            _, loss_val, stats, train_state_dual =  Training. single_train_step! (
287329                AutoZygote (),          #  AD backend
288330                training_step_loop_dual. loss_fn,  #  Loss function
289-                 (θ, training_step_loop_dual. hyperparameters, training_step_loop_dual. pre_hook (θ, primal_model, train_state_primal, dual_model, train_state_dual)... ), #  Data
290-                 train_state_dual #  Training state
331+                 (
332+                     θ,
333+                     training_step_loop_dual. hyperparameters,
334+                     training_step_loop_dual. pre_hook (
335+                         θ,
336+                         primal_model,
337+                         train_state_primal,
338+                         dual_model,
339+                         train_state_dual,
340+                     )... ,
341+                 ), #  Data
342+                 train_state_dual, #  Training state
291343            )
292344            current_states_dual[iter_batch] =  stats
293345            iter_batch +=  1 
@@ -296,7 +348,11 @@ function L2OALM_epoch!(
296348        iter_dual +=  1 
297349    end 
298350    for  fn in  training_step_loop_dual. parameter_update_fns
299-         fn (training_step_loop_primal. hyperparameters, training_step_loop_dual. hyperparameters, current_state_dual)
351+         fn (
352+             training_step_loop_primal. hyperparameters,
353+             training_step_loop_dual. hyperparameters,
354+             current_state_dual,
355+         )
300356    end 
301357    return 
302358end 
@@ -334,20 +390,33 @@ function L2OALM_train!(
334390    train_state_primal:: Lux.Training.TrainState ,
335391    train_state_dual:: Lux.Training.TrainState ,
336392    data;
337-     training_step_loop_primal:: TrainingStepLoop = _default_primal_loop (bm),
338-     training_step_loop_dual:: TrainingStepLoop = _default_dual_loop (num_equal),
339-     stopping_criteria:: Vector{F} = [(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) ->  iter >=  100  ?  true  :  false ],
340- ) where  F<: Function 
393+     training_step_loop_primal:: TrainingStepLoop  =  _default_primal_loop (bm),
394+     training_step_loop_dual:: TrainingStepLoop  =  _default_dual_loop (num_equal),
395+     stopping_criteria:: Vector{F}  =  [
396+         (iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> 
397+             iter >=  100  ?  true  :  false ,
398+     ],
399+ ) where  {F<: Function }
341400    iter =  1 
342-     while  all (stopping_criterion (iter, primal_model, dual_model, train_state_primal, train_state_dual, training_step_loop_primal. hyperparameters, training_step_loop_dual. hyperparameters) for  stopping_criterion in  stopping_criteria)
401+     while  all (
402+         stopping_criterion (
403+             iter,
404+             primal_model,
405+             dual_model,
406+             train_state_primal,
407+             train_state_dual,
408+             training_step_loop_primal. hyperparameters,
409+             training_step_loop_dual. hyperparameters,
410+         ) for  stopping_criterion in  stopping_criteria
411+     )
343412        L2OALM_epoch! (
344413            primal_model,
345414            train_state_primal,
346415            dual_model,
347416            train_state_dual,
348417            training_step_loop_primal,
349418            training_step_loop_dual,
350-             data
419+             data, 
351420        )
352421        iter +=  1 
353422    end 
0 commit comments