Conversation

andrewrosemberg (Contributor)

Adds Augmented Lagrangian Primal-Dual Learning Method

@andrewrosemberg self-assigned this on Jul 18, 2025
@andrewrosemberg changed the title from "WIP: Augmented Lagrangian Learning Method" to "Augmented Lagrangian Learning Method" on Jul 22, 2025
@andrewrosemberg requested a review from @klamike on Jul 22, 2025 at 22:06
@andrewrosemberg added the enhancement (New feature or request) label on Jul 22, 2025

@klamike left a comment

Awesome work! I have a few comments, nothing major

Project.toml Outdated
Comment on lines 6 to 13
[deps]
BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Do we need all of these? In particular CUDA, LuxCUDA, ExaModels?

test/runtests.jl Outdated
train_state_dual,
data,
stopping_criteria = [validation_testset],
)

src/L2OALM.jl Outdated
Keywords:
- `max_dual`: Maximum value for the target dual variables.
"""
function LagrangianDualLoss(num_equal::Int; max_dual = 1e6)

Just a note, we should probably eventually have a somewhat standard interface for L2OMethods and their hyperparameters, i.e.

struct ALMMethod <: AbstractL2OMethod
  bm::BatchModel
  max_dual::Float64
  ρ_init::Float64
end

or

struct ALMMethod <: AbstractL2OMethod
  bm::BatchModel
  hyperparameters::Dict{Symbol,Any}
end

Ideally that also would help to clean up stuff like

L2OALM.jl/src/L2OALM.jl, lines 196 to 197 in 0797bb7:

hpm_primal[:ρ] = min(hpm_primal[:ρmax], hpm_primal[:ρ] * hpm_primal[:α])
hpm_dual[:ρ] = hpm_primal[:ρ] # Ensure dual model uses the same ρ
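For example, with a struct like the first ALMMethod above, that update becomes plain field access; a minimal sketch (field and function names assumed, not from the PR):

# Minimal sketch, names assumed: the Dict lookups above become field access.
mutable struct ALMState
    ρ::Float64
    ρmax::Float64
    α::Float64
end

# Standard ALM penalty growth: ρ ← min(ρmax, α·ρ)
update_rho!(s::ALMState) = (s.ρ = min(s.ρmax, s.ρ * s.α); s)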

andrewrosemberg (Contributor, Author) replied:

mutable struct PrimalDualTrainer
    primal_model::Lux.AbstractLuxLayer              # type annotations assumed
    primal_training_state::Lux.Training.TrainState
    dual_model::Lux.AbstractLuxLayer
    dual_training_state::Lux.Training.TrainState
    data::DataLoader                                # MLUtils.DataLoader
end


nvar = model.meta.nvar
ncon = model.meta.ncon
nθ = length(model.θ)

This is such a common thing that BNK should probably store these itself and expose a frontend like num_parameters, num_variables, num_constraints.
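For example (a sketch; this assumes the BatchModel keeps the underlying ExaModel in a field, here bm.model, which may not match BNK's actual layout):

num_variables(bm)   = bm.model.meta.nvar     # field name assumed
num_constraints(bm) = bm.model.meta.ncon
num_parameters(bm)  = length(bm.model.θ)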

test/runtests.jl Outdated
Comment on lines 124 to 127
gh_bound = gh_test[1:end-num_equal, :]
gh_equal = gh_test[end-num_equal+1:end, :]
dual_hat_bound = dual_hat[1:end-num_equal, :]
dual_hat_equal = dual_hat[end-num_equal+1:end, :]

This is another obvious thing BNK should have -- functions that help you deal with indices
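For example, hypothetical helpers mirroring the slicing above (assuming the equality rows always come last):

bound_part(A, num_equal)    = @view A[1:end-num_equal, :]      # helper names hypothetical
equality_part(A, num_equal) = @view A[end-num_equal+1:end, :]

# usage, matching the test: gh_bound = bound_part(gh_test, num_equal)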

src/L2OALM.jl Outdated
Comment on lines 212 to 218
Dict{Symbol,Any}(
:ρ => 1.0,
:ρmax => 1e6,
:τ => 0.8,
:α => 10.0,
:max_violation => 0.0,
),

Does Lux let you make these structs instead?
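Nothing in Lux prevents it; the container is opaque to Lux, so a plain struct with the defaults above would work, e.g. this sketch (struct name hypothetical):

Base.@kwdef mutable struct ALMHyperparameters
    ρ::Float64 = 1.0
    ρmax::Float64 = 1e6
    τ::Float64 = 0.8
    α::Float64 = 10.0
    max_violation::Float64 = 0.0
end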

src/L2OALM.jl Outdated
Default function that reconciles the state of the dual model after processing a batch of data.
This function computes the mean dual loss from the batch states.
"""
function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})

Suggested change
function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
function _reconcile_dual_state(batch_states::Vector{NamedTuple})

The alm prefix can be removed since this whole repo is ALM 😄 (this needs updates everywhere else, and for the primal version, update_rho, etc. too; let me know if you agree and I can add that commit)

src/L2OALM.jl Outdated
Comment on lines 230 to 239
function _default_dual_loop(num_equal::Int)
return TrainingStepLoop(
LagrangianDualLoss(num_equal),
[(iter, current_state, hpm) -> iter >= 100 ? true : false],
Dict{Symbol,Any}(:max_dual => 1e6, :ρ => 1.0),
[],
_reconcile_alm_dual_state,
_pre_hook_dual,
)
end

How about exposing the hyperparameters as kwargs here? Same for the primal one.
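i.e. something like this sketch (defaults taken from the current code; max_iter is an assumed name for the hardcoded 100):

function _default_dual_loop(num_equal::Int; max_dual = 1e6, ρ = 1.0, max_iter = 100)
    return TrainingStepLoop(
        LagrangianDualLoss(num_equal; max_dual = max_dual),
        [(iter, current_state, hpm) -> iter >= max_iter],
        Dict{Symbol,Any}(:max_dual => max_dual, :ρ => ρ),
        [],
        _reconcile_alm_dual_state,
        _pre_hook_dual,
    )
end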

src/L2OALM.jl Outdated
Comment on lines 277 to 280
stopping_criterion(
iter_primal,
current_state_primal,
training_step_loop_primal.hyperparameters,

Is this some standard Lux API? Our stopping criteria don't need the state or the hyperparameters.

src/L2OALM.jl Outdated
Comment on lines 115 to 127
function _pre_hook_primal(
θ,
primal_model,
train_state_primal,
dual_model,
train_state_dual,
bm,
)
# Forward pass for dual model
dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states)

return (dual_hat_k,)
end

Is there some guidance for what to put in a "pre-hook" vs. a "loss"? Do they get treated differently somehow?

andrewrosemberg (Contributor, Author) replied:

Keep the hooks, but move the primal and dual evaluation inside the loop with ChainRulesCore.ignore_derivatives.
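Roughly like this sketch (signature illustrative, not the actual loss):

using ChainRulesCore: ignore_derivatives

# The dual forward pass runs inside the primal loss but is hidden from AD,
# so the separate pre-hook machinery is no longer needed.
function primal_loss(θ, dual_model, train_state_dual)
    dual_hat_k = ignore_derivatives() do
        first(dual_model(θ, train_state_dual.parameters, train_state_dual.states))
    end
    # ... assemble the augmented Lagrangian terms from dual_hat_k here ...
end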

src/L2OALM.jl Outdated
Comment on lines 100 to 107
mutable struct TrainingStepLoop
loss_fn::Function
stopping_criteria::Vector{Function}
hyperparameters::Dict{Symbol,Any}
parameter_update_fns::Vector{Function}
reconcile_state::Function
pre_hook::Function
end
@klamike Jul 23, 2025

Why Vector{Function} for parameter_update_fns and stopping_criteria? I think it only ever uses one.

I see for the dual case there is no parameter_update_fn. I guess (x...) -> nothing can work there.

Θ_train = randn(T, nθ, dataset_size) |> dev_gpu
Θ_test = randn(T, nθ, dataset_size) |> dev_gpu

primal_model = feed_forward_builder(nθ, nvar, [320, 320])
@klamike Jul 23, 2025

Not sure where, but we should eventually have some magic for this... something like L2ONN.feed_forward(bm, input=:all_params, output=:all_vars, hidden_sizes=[320,320])
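A hypothetical sketch of that helper (all names assumed; uses the num_* accessors sketched earlier):

using Lux

function feed_forward(bm; hidden_sizes = [320, 320])
    nin  = num_parameters(bm)    # input  = :all_params
    nout = num_variables(bm)     # output = :all_vars
    sizes = [nin; hidden_sizes; nout]
    layers = [Dense(sizes[i] => sizes[i+1], i < length(sizes) - 1 ? relu : identity)
              for i in 1:length(sizes)-1]
    return Chain(layers...)
end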

Comment on lines +63 to +64
bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full))
bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full))
@klamike Jul 23, 2025

Suggested change
bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full))
bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full))
bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:viol_grad))
bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:viol_grad))

viol_grad suffices here and avoids allocating jprod and Hessian storage.

Project.toml Outdated
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[sources]
BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"}

Suggested change
BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"}
BatchNLPKernels = {url = "https://github.com/LearningToOptimize/BatchNLPKernels.jl"}
