Augmented Lagrangian Learning Method #2
Conversation
Awesome work! I have a few comments, nothing major.
Project.toml
Outdated
[deps]
BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Do we need all of these? In particular CUDA, LuxCUDA, ExaModels?
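If CUDA, LuxCUDA, and ExaModels are only exercised by the test suite, one option is to demote them to test-only dependencies. A sketch, assuming the package code itself never imports them:

[deps]
BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CUDA", "ExaModels", "LuxCUDA", "Test"]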
test/runtests.jl
Outdated
train_state_dual,
data,
stopping_criteria = [validation_testset],
)
src/L2OALM.jl
Outdated
Keywords:
- `max_dual`: Maximum value for the target dual variables.
"""
function LagrangianDualLoss(num_equal::Int; max_dual = 1e6)
Just a note: we should probably eventually have a somewhat standard interface for L2OMethods and their hyperparameters, e.g.
struct ALMMethod <: AbstractL2OMethod
bm::BatchModel
max_dual::Float64
ρ_init::Float64
end
or
struct ALMMethod <: AbstractL2OMethod
bm::BatchModel
hyperparameters::Dict{Symbol,Any}
end
Ideally that would also help to clean up things like
Lines 196 to 197 in 0797bb7:
hpm_primal[:ρ] = min(hpm_primal[:ρmax], hpm_primal[:ρ] * hpm_primal[:α])
hpm_dual[:ρ] = hpm_primal[:ρ] # Ensure dual model uses the same ρ
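With the typed-struct variant, that update could collapse to something like this (a sketch with hypothetical field names, assuming the primal and dual loops share one method object):

method.ρ = min(method.ρmax, method.ρ * method.α)  # single source of truth for ρ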
# sketch; the concrete types are illustrative Lux v1 names
mutable struct PrimalDualTrainer
    primal_model::Lux.AbstractLuxLayer
    primal_training_state::Lux.Training.TrainState
    dual_model::Lux.AbstractLuxLayer
    dual_training_state::Lux.Training.TrainState
    data::DataLoader  # e.g. MLUtils.DataLoader
end
nvar = model.meta.nvar
ncon = model.meta.ncon
nθ = length(model.θ)
This is such a common thing that BNK should probably have a field for nθ and expose a frontend like num_parameters, num_variables, and num_constraints.
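A minimal sketch of what those accessors could look like, assuming BatchModel stores the underlying model with NLPModels-style meta and an ExaModels parameter vector θ (the field and function names are hypothetical):

num_variables(bm::BNK.BatchModel) = bm.model.meta.nvar    # hypothetical accessor
num_constraints(bm::BNK.BatchModel) = bm.model.meta.ncon  # hypothetical accessor
num_parameters(bm::BNK.BatchModel) = length(bm.model.θ)   # hypothetical accessor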
test/runtests.jl
Outdated
gh_bound = gh_test[1:end-num_equal, :]
gh_equal = gh_test[end-num_equal+1:end, :]
dual_hat_bound = dual_hat[1:end-num_equal, :]
dual_hat_equal = dual_hat[end-num_equal+1:end, :]
This is another obvious thing BNK should have -- functions that help you deal with indices
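For instance, a sketch of hypothetical helpers that split constraint rows into bound and equality blocks:

# hypothetical index helpers that BNK could export
bound_rows(nrows, num_equal)    = 1:(nrows - num_equal)
equality_rows(nrows, num_equal) = (nrows - num_equal + 1):nrows

gh_bound = gh_test[bound_rows(size(gh_test, 1), num_equal), :]
gh_equal = gh_test[equality_rows(size(gh_test, 1), num_equal), :]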
src/L2OALM.jl
Outdated
Dict{Symbol,Any}(
    :ρ => 1.0,
    :ρmax => 1e6,
    :τ => 0.8,
    :α => 10.0,
    :max_violation => 0.0,
),
Does Lux let you make these structs instead?
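Nothing in Lux forces a Dict here; a plain Julia struct with keyword defaults would work. A sketch with a hypothetical name:

Base.@kwdef mutable struct ALMPrimalHyperparameters
    ρ::Float64 = 1.0
    ρmax::Float64 = 1e6
    τ::Float64 = 0.8
    α::Float64 = 10.0
    max_violation::Float64 = 0.0
end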
src/L2OALM.jl
Outdated
Default function that reconciles the state of the dual model after processing a batch of data.
This function computes the mean dual loss from the batch states.
"""
function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
Suggested change:
- function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
+ function _reconcile_dual_state(batch_states::Vector{NamedTuple})
The alm prefix can be dropped since ALM is this whole repo 😄 (this needs updates everywhere else too: the primal version, update_rho, etc. Let me know if you agree and I can add that commit.)
src/L2OALM.jl
Outdated
function _default_dual_loop(num_equal::Int)
    return TrainingStepLoop(
        LagrangianDualLoss(num_equal),
        [(iter, current_state, hpm) -> iter >= 100 ? true : false],
        Dict{Symbol,Any}(:max_dual => 1e6, :ρ => 1.0),
        [],
        _reconcile_alm_dual_state,
        _pre_hook_dual,
    )
end
How about exposing the hyperparameters as kwargs here? Same for the primal one.
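Something along these lines (a sketch; the keyword names and the max_iter default are illustrative):

function _default_dual_loop(num_equal::Int; max_dual = 1e6, ρ = 1.0, max_iter = 100)
    return TrainingStepLoop(
        LagrangianDualLoss(num_equal; max_dual),
        [(iter, current_state, hpm) -> iter >= max_iter],
        Dict{Symbol,Any}(:max_dual => max_dual, :ρ => ρ),
        [],
        _reconcile_alm_dual_state,
        _pre_hook_dual,
    )
end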
src/L2OALM.jl
Outdated
stopping_criterion(
    iter_primal,
    current_state_primal,
    training_step_loop_primal.hyperparameters,
Is this some standard Lux API? Our stopping criteria need neither the state nor the hyperparameters.
src/L2OALM.jl
Outdated
function _pre_hook_primal(
    θ,
    primal_model,
    train_state_primal,
    dual_model,
    train_state_dual,
    bm,
)
    # Forward pass for dual model
    dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states)

    return (dual_hat_k,)
end
Is there some guidance for what to put in a "pre-hook" vs a "loss"? Do they get treated differently somehow?
Keep the hooks, but move the primal and dual evaluation inside the loop with ChainRulesCore.ignore_derivatives.
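A minimal sketch of that idea, assuming a Lux-style loss signature (model, ps, st, data) -> (loss, st, stats); the augmented_lagrangian helper is hypothetical:

using ChainRulesCore: ignore_derivatives

function primal_loss(primal_model, ps, st, (θ, bm, dual_model, train_state_dual))
    # evaluate the dual network without differentiating through it
    dual_hat_k = ignore_derivatives() do
        first(dual_model(θ, train_state_dual.parameters, train_state_dual.states))
    end
    x_hat, st_ = primal_model(θ, ps, st)
    loss = augmented_lagrangian(bm, x_hat, dual_hat_k)  # hypothetical ALM objective
    return loss, st_, (;)
end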
src/L2OALM.jl
Outdated
mutable struct TrainingStepLoop
    loss_fn::Function
    stopping_criteria::Vector{Function}
    hyperparameters::Dict{Symbol,Any}
    parameter_update_fns::Vector{Function}
    reconcile_state::Function
    pre_hook::Function
end
Why Vector{Function} for parameter_update_fns and stopping_criteria? I think only one is ever used.
I see that for the dual case there is no parameter_update_fn; I guess (x...) -> nothing can work there.
Θ_train = randn(T, nθ, dataset_size) |> dev_gpu
Θ_test = randn(T, nθ, dataset_size) |> dev_gpu

primal_model = feed_forward_builder(nθ, nvar, [320, 320])
Not sure where, but we should eventually have some magic for this... something like L2ONN.feed_forward(bm, input=:all_params, output=:all_vars, hidden_sizes=[320,320])
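A sketch of what that builder could look like, assuming the hypothetical num_parameters/num_variables accessors from above and Lux's Dense/Chain layers:

# hypothetical builder: maps all problem parameters to all primal variables
function feed_forward(bm; hidden_sizes = [320, 320], act = Lux.relu)
    sizes = [num_parameters(bm); hidden_sizes; num_variables(bm)]
    layers = [
        Lux.Dense(sizes[i] => sizes[i+1], i == length(sizes) - 1 ? identity : act)
        for i in 1:length(sizes)-1
    ]
    return Lux.Chain(layers...)
end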
Suggested change:
- bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full))
- bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full))
+ bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:viol_grad))
+ bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:viol_grad))
viol_grad suffices here, and avoids allocating storage for the Jacobian products and the Hessian.
Project.toml
Outdated
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[sources]
BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"}
Suggested change:
- BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"}
+ BatchNLPKernels = {url = "https://github.com/LearningToOptimize/BatchNLPKernels.jl"}
Adds Augmented Lagrangian Primal-Dual Learning Method