From 07324830e0a8754fcc15d15036f527b734c18269 Mon Sep 17 00:00:00 2001 From: Andrew Date: Fri, 18 Jul 2025 16:16:55 -0400 Subject: [PATCH 01/15] add primal dual losses --- src/L2OALM.jl | 75 +++++++++++++++++- test/power.jl | 199 +++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 101 +++++++++++++++++++++++- 3 files changed, 373 insertions(+), 2 deletions(-) create mode 100644 test/power.jl diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 33126a5..4f629ef 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -1,5 +1,78 @@ module L2OALM -# Write your package code here. +using BatchNLPKernels +using ExaModels +using KernelAbstractions +using DifferentiationInterface +const DI = DifferentiationInterface + +import Zygote +import FiniteDifferences + +using PowerModels +PowerModels.silence() +using PGLib +using LinearAlgebra + +using Lux +using LuxCUDA +using Lux.Training +using MLUtils +using Optimisers +using CUDA +using Random +import GPUArraysCore: @allowscalar + +using OpenCL, pocl_jll, AcceleratedKernels + +function LagrangianDualLoss(;max_dual=1e6) + return (dual_model, ps_dual, st_dual, data) -> begin + x, dual_hat_k, gh, ρ = data + + # Get current dual predictions + dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual) + + # Separate bound and equality constraints + gh_bound = gh[1:end-n_bus*2,:] + gh_equal = gh[end-n_bus*2+1:end,:] + dual_hat_bound = dual_hat_k[1:end-n_bus*2,:] + dual_hat_equal = dual_hat_k[end-n_bus*2+1:end,:] + + # Target for dual variables + dual_target = vcat( + min.(max.(dual_hat_bound + ρ .* gh_bound, 0), max_dual), + min.(dual_hat_equal + ρ .* gh_equal, max_dual) + ) + + loss = mean((dual_hat .- dual_target).^2) + return loss, st_dual_new, (dual_loss=loss,) + end +end + +function LagrangianPrimalLoss(bm::BatchModel) + return (model, ps, st, data) -> begin + Θ, dual_hat, ρ = data + num_s = size(Θ, 2) + + # Forward pass for prediction + X̂, st_new = model(Θ, ps, st) + + # Calculate violations and objectives + objs = BNK.objective!(bm, X̂, Θ) + # gh = constraints!(bm, X̂, Θ) + Vc, Vb = BNK.all_violations!(bm, X̂, Θ) + V = vcat(Vb, Vc) + total_loss = ( + sum(abs.(dual_hat .* V)) / num_s + + ρ / 2 * sum((V).^2) / num_s + + mean(objs) + ) + + return total_loss, st_new, ( + total_loss=total_loss, + ) + end +end + end diff --git a/test/power.jl b/test/power.jl new file mode 100644 index 0000000..545e707 --- /dev/null +++ b/test/power.jl @@ -0,0 +1,199 @@ +function _get_case_file(filename::String) + isfile(filename) && return filename + + cached = joinpath(PGLib.PGLib_opf, filename) + !isfile(cached) && error("File $filename not found in PGLib/pglib-opf") + return cached +end +function _build_power_ref(filename) + path = _get_case_file(filename) + data = PowerModels.parse_file(path) + PowerModels.standardize_cost_terms!(data, order = 2) + PowerModels.calc_thermal_limits!(data) + return PowerModels.build_ref(data)[:it][:pm][:nw][0] +end + + +_convert_array(data::N, backend) where {names,N<:NamedTuple{names}} = + NamedTuple{names}(ExaModels.convert_array(d, backend) for d in data) +function _parse_ac_data_raw(filename; T=Float64) + ref = _build_power_ref(filename) # FIXME: only parse once + arcdict = Dict(a => k for (k, a) in enumerate(ref[:arcs])) + busdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:bus])) + gendict = Dict(k => i for (i, (k, _)) in enumerate(ref[:gen])) + branchdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:branch])) + return ( + bus = [ + begin + loads = [ref[:load][l] for l in ref[:bus_loads][k]] + shunts = [ref[:shunt][s] for s 
in ref[:bus_shunts][k]] + pd = T(sum(load["pd"] for load in loads; init = 0.0)) + qd = T(sum(load["qd"] for load in loads; init = 0.0)) + gs = T(sum(shunt["gs"] for shunt in shunts; init = 0.0)) + bs = T(sum(shunt["bs"] for shunt in shunts; init = 0.0)) + (i = busdict[k], pd = pd, gs = gs, qd = qd, bs = bs) + end for (k, _) in ref[:bus] + ], + gen = [ + ( + i = gendict[k], + cost1 = T(v["cost"][1]), + cost2 = T(v["cost"][2]), + cost3 = T(v["cost"][3]), + bus = busdict[v["gen_bus"]], + ) for (k, v) in ref[:gen] + ], + arc = [ + (i = k, rate_a = T(ref[:branch][l]["rate_a"]), bus = busdict[i]) for + (k, (l, i, _)) in enumerate(ref[:arcs]) + ], + branch = [ + begin + branch = branch_raw + f_idx = arcdict[i, branch["f_bus"], branch["t_bus"]] + t_idx = arcdict[i, branch["t_bus"], branch["f_bus"]] + + g, b = PowerModels.calc_branch_y(branch) + tr, ti = PowerModels.calc_branch_t(branch) + ttm = tr^2 + ti^2 + + g_fr = branch["g_fr"]; b_fr = branch["b_fr"] + g_to = branch["g_to"]; b_to = branch["b_to"] + + ( + i = branchdict[i], + j = 1, + f_idx = f_idx, + t_idx = t_idx, + f_bus = busdict[branch["f_bus"]], + t_bus = busdict[branch["t_bus"]], + c1 = T((-g * tr - b * ti) / ttm), + c2 = T((-b * tr + g * ti) / ttm), + c3 = T((-g * tr + b * ti) / ttm), + c4 = T((-b * tr - g * ti) / ttm), + c5 = T((g + g_fr) / ttm), + c6 = T((b + b_fr) / ttm), + c7 = T((g + g_to)), + c8 = T((b + b_to)), + rate_a_sq = T(branch["rate_a"]^2), + ) + end for (i, branch_raw) in ref[:branch] + ], + ref_buses = [busdict[i] for (i, _) in ref[:ref_buses]], + vmax = [T(v["vmax"]) for (_, v) in ref[:bus]], + vmin = [T(v["vmin"]) for (_, v) in ref[:bus]], + pmax = [T(v["pmax"]) for (_, v) in ref[:gen]], + pmin = [T(v["pmin"]) for (_, v) in ref[:gen]], + qmax = [T(v["qmax"]) for (_, v) in ref[:gen]], + qmin = [T(v["qmin"]) for (_, v) in ref[:gen]], + rate_a = [T(ref[:branch][l]["rate_a"]) for (l, _, _) in ref[:arcs]], + angmax = [T(b["angmax"]) for (_, b) in ref[:branch]], + angmin = [T(b["angmin"]) for (_, b) in ref[:branch]], + ) +end + +_parse_ac_data(filename) = _parse_ac_data_raw(filename) +function _parse_ac_data(filename, backend; T=Float64) + _convert_array(_parse_ac_data_raw(filename, T=T), backend) +end + +# Parametric version + +function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_ieee.m"; + prod::Bool = true, backend = OpenCLBackend(), T=Float64, kwargs...) 
+ data = _parse_ac_data(filename, backend, T=T) + c = ExaCore(T; backend = backend) + + va = variable(c, length(data.bus);) + vm = variable( + c, + length(data.bus); + start = fill!(similar(data.bus, T), 1.0), + lvar = data.vmin, + uvar = data.vmax, + ) + + pg = variable(c, length(data.gen); lvar = data.pmin, uvar = data.pmax) + qg = variable(c, length(data.gen); lvar = data.qmin, uvar = data.qmax) + + @allowscalar pd = parameter(c, [b.pd for b in data.bus]) + @allowscalar qd = parameter(c, [b.qd for b in data.bus]) + + p = variable(c, length(data.arc); lvar = -data.rate_a, uvar = data.rate_a) + q = variable(c, length(data.arc); lvar = -data.rate_a, uvar = data.rate_a) + + o = objective(c, g.cost1 * pg[g.i]^2 + g.cost2 * pg[g.i] + g.cost3 for g in data.gen) + + # Reference bus angle ------------------------------------------------------ + c1 = constraint(c, va[i] for i in data.ref_buses) + + # Branch power-flow equations --------------------------------------------- + constraint( + c, + (p[b.f_idx] - b.c5 * vm[b.f_bus]^2 - + b.c3 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) - + b.c4 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for + b in data.branch), + ) + + constraint( + c, + (q[b.f_idx] + b.c6 * vm[b.f_bus]^2 + + b.c4 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) - + b.c3 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for + b in data.branch), + ) + + constraint( + c, + (p[b.t_idx] - b.c7 * vm[b.t_bus]^2 - + b.c1 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) - + b.c2 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for + b in data.branch), + ) + + constraint( + c, + (q[b.t_idx] + b.c8 * vm[b.t_bus]^2 + + b.c2 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) - + b.c1 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for + b in data.branch), + ) + + # Angle difference limits -------------------------------------------------- + constraint( + c, + va[b.f_bus] - va[b.t_bus] for b in data.branch; lcon = data.angmin, ucon = data.angmax, + ) + + # Apparent power thermal limits ------------------------------------------- + constraint( + c, + p[b.f_idx]^2 + q[b.f_idx]^2 - b.rate_a_sq for b in data.branch; + lcon = fill!(similar(data.branch, Float64, length(data.branch)), -Inf), + ) + constraint( + c, + p[b.t_idx]^2 + q[b.t_idx]^2 - b.rate_a_sq for b in data.branch; + lcon = fill!(similar(data.branch, Float64, length(data.branch)), -Inf), + ) + + # Power balance at each bus ----------------------------------------------- + load_balance_p = constraint(c, pd[b.i] + b.gs * vm[b.i]^2 for b in data.bus) + load_balance_q = constraint(c, qd[b.i] - b.bs * vm[b.i]^2 for b in data.bus) + + # Map arc & generator variables into the bus balance equations + constraint!(c, load_balance_p, a.bus => p[a.i] for a in data.arc) + constraint!(c, load_balance_q, a.bus => q[a.i] for a in data.arc) + constraint!(c, load_balance_p, g.bus => -pg[g.i] for g in data.gen) + constraint!(c, load_balance_q, g.bus => -qg[g.i] for g in data.gen) + + return ExaModel(c; prod = prod) +end + +function create_power_models(backend = OpenCLBackend(), T=Float64) + models = ExaModel[] + push!(models, create_ac_power_model("pglib_opf_case14_ieee.m"; backend = backend)) + names = ["AC-OPF – IEEE-14"] + return models, names +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c13f2e1..816669b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,5 +2,104 @@ using L2OALM using 
Test @testset "L2OALM.jl" begin - # Write your tests here. + function feed_forward_builder( + num_p::Integer, + num_y::Integer, + hidden_layers::AbstractVector{<:Integer}; + activation = relu, + ) + """ + Builds a Chain of Dense layers with Lux + """ + # Combine all layers: input size, hidden sizes, output size + layer_sizes = [num_p; hidden_layers; num_y] + + # Build up a list of Dense layers + dense_layers = Any[] + for i in 1:(length(layer_sizes)-1) + if i < length(layer_sizes) - 1 + # Hidden layers with activation + push!(dense_layers, Dense(layer_sizes[i], layer_sizes[i+1], activation)) + else + # Final layer with no activation + push!(dense_layers, Dense(layer_sizes[i], layer_sizes[i+1])) + end + end + + return Chain(dense_layers...) + end + + function test_penalty_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = gpu_device(), + backend=CPU(), + batch_size = 32, + dataset_size = 3200, + rng = Random.default_rng(), + T = Float64, + ) + model = create_parametric_ac_power_model(filename; backend = backend, T=T) + bm = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) + bm_all = BNK.BatchModel(model, dataset_size, config=BNK.BatchModelConfig(:full)) + + function PenaltyLoss(model, ps, st, Θ) + X̂ , st_new = model(Θ, ps, st) + + obj = BNK.objective!(bm, X̂, Θ) + Vc, Vb = BNK.all_violations!(bm, X̂, Θ) + + return sum(obj) + 1000 * sum(Vc) + 1000 * sum(Vb), st_new, (;obj=sum(obj), Vc=sum(Vc), Vb=sum(Vb)) + end + + nvar = model.meta.nvar + ncon = model.meta.ncon + nθ = length(model.θ) + + Θ_train = randn(T, nθ, dataset_size) |> dev_gpu + + lux_model = feed_forward_builder(nθ, nvar, [320, 320]) + + ps_model, st_model = Lux.setup(rng, lux_model) + ps_model = ps_model |> dev_gpu + st_model = st_model |> dev_gpu + + X̂ , _ = lux_model(Θ_train, ps_model, st_model) + + y = BNK.objective!(bm_all, X̂, Θ_train) + + @test length(y) == dataset_size + Vc, Vb = BNK.all_violations!(bm_all, X̂, Θ_train) + @test size(Vc) == (ncon, dataset_size) + @test size(Vb) == (nvar, dataset_size) + + lagrangian_prev = sum(y) + 1000 * sum(Vc) + 1000 * sum(Vb) + + train_state = Training.TrainState(lux_model, ps_model, st_model, Optimisers.Adam(1e-5)) + + data = DataLoader((Θ_train); batchsize=batch_size, shuffle=true) .|> dev_gpu + for (Θ) in data + _, loss_val, stats, train_state = Training.single_train_step!( + AutoZygote(), # AD backend + PenaltyLoss, + (Θ), # data + train_state + ) + end + + X̂ , st_model = lux_model(Θ_train, ps_model, st_model) + + y = BNK.objective!(bm_all, X̂, Θ_train) + Vc, Vb = BNK.all_violations!(bm_all, X̂, Θ_train) + lagrangian_new = sum(y) + 1000 * sum(Vc) + 1000 * sum(Vb) + + @test lagrangian_new < lagrangian_prev + end + + @testset "Penalty Training" begin + backend, dev = if haskey(ENV, "BNK_TEST_CUDA") + CUDABackend(), gpu_device() + else + CPU(), cpu_device() + end + + test_penalty_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = dev, backend=backend, T=Float32) + end end From e4d4845ee47d0a6f86b3cd3aa3472be6a7a2caf8 Mon Sep 17 00:00:00 2001 From: Andrew Date: Fri, 18 Jul 2025 16:18:54 -0400 Subject: [PATCH 02/15] fix CI and docs --- README.md | 8 ++++---- docs/make.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index a4d73fd..a9fa2de 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ Learning To Optimize using the Augmented Lagrangian Primal-Dual Method. 
-[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://andrewrosemberg.github.io/L2OALM.jl/stable/) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://andrewrosemberg.github.io/L2OALM.jl/dev/) -[![Build Status](https://github.com/andrewrosemberg/L2OALM.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/andrewrosemberg/L2OALM.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/andrewrosemberg/L2OALM.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/andrewrosemberg/L2OALM.jl) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://LearningToOptimize.github.io/L2OALM.jl/stable/) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://LearningToOptimize.github.io/L2OALM.jl/dev/) +[![Build Status](https://github.com/LearningToOptimize/L2OALM.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/LearningToOptimize/L2OALM.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/LearningToOptimize/L2OALM.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/LearningToOptimize/L2OALM.jl) diff --git a/docs/make.jl b/docs/make.jl index e8c1d33..de4ea88 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs(; authors="Andrew and contributors", sitename="L2OALM.jl", format=Documenter.HTML(; - canonical="https://andrewrosemberg.github.io/L2OALM.jl", + canonical="https://LearningToOptimize.github.io/L2OALM.jl", edit_link="main", assets=String[], ), @@ -18,6 +18,6 @@ makedocs(; ) deploydocs(; - repo="github.com/andrewrosemberg/L2OALM.jl", + repo="github.com/LearningToOptimize/L2OALM.jl", devbranch="main", ) From df50359ad8afafb3a280e96268cbc1d7845022df Mon Sep 17 00:00:00 2001 From: Andrew Date: Fri, 18 Jul 2025 16:19:12 -0400 Subject: [PATCH 03/15] fix ci and docs --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 03d7322..d71729e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,7 +4,7 @@ CurrentModule = L2OALM # L2OALM -Documentation for [L2OALM](https://github.com/andrewrosemberg/L2OALM.jl). +Documentation for [L2OALM](https://github.com/LearningToOptimize/L2OALM.jl). 
```@index ``` From 451e46f22ae87c7b758f55e2cf4e686f7db6cebb Mon Sep 17 00:00:00 2001 From: Andrew Date: Fri, 18 Jul 2025 16:29:55 -0400 Subject: [PATCH 04/15] update deps --- Project.toml | 9 +++++++++ src/L2OALM.jl | 15 +-------------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 41d83ca..4a83842 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,15 @@ uuid = "f31bfc7b-7b5d-4cc3-b76b-1af281ce159d" authors = ["Andrew and contributors"] version = "1.0.0-DEV" +[deps] +BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" + [compat] julia = "1.6.7" diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 4f629ef..497fe3b 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -2,17 +2,6 @@ module L2OALM using BatchNLPKernels using ExaModels -using KernelAbstractions -using DifferentiationInterface -const DI = DifferentiationInterface - -import Zygote -import FiniteDifferences - -using PowerModels -PowerModels.silence() -using PGLib -using LinearAlgebra using Lux using LuxCUDA @@ -20,10 +9,8 @@ using Lux.Training using MLUtils using Optimisers using CUDA -using Random -import GPUArraysCore: @allowscalar -using OpenCL, pocl_jll, AcceleratedKernels +# using OpenCL, pocl_jll, AcceleratedKernels function LagrangianDualLoss(;max_dual=1e6) return (dual_model, ps_dual, st_dual, data) -> begin From 7ac9513477b40431643da44292d470fb868855bd Mon Sep 17 00:00:00 2001 From: Andrew Date: Mon, 21 Jul 2025 15:56:07 -0400 Subject: [PATCH 05/15] update loop alm --- src/L2OALM.jl | 173 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 170 insertions(+), 3 deletions(-) diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 497fe3b..0db1a2c 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -12,10 +12,21 @@ using CUDA # using OpenCL, pocl_jll, AcceleratedKernels +""" + LagrangianDualLoss(;max_dual=1e6) + +Returns a function that computes the MSE loss for the (dual-)model predicting lagrangian dual variables +from constraint evaluations `gh` and the current dual predictions `dual_hat_k`. +Target is calculated using the augmented lagrangian method. +Target dual variables are clipped from zero to `max_dual`. + +Keywords: + - `max_dual`: Maximum value for the target dual variables. +""" function LagrangianDualLoss(;max_dual=1e6) return (dual_model, ps_dual, st_dual, data) -> begin - x, dual_hat_k, gh, ρ = data - + x, hpm, dual_hat_k, gh = data + ρ = hpm.ρ # Get current dual predictions dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual) @@ -36,9 +47,19 @@ function LagrangianDualLoss(;max_dual=1e6) end end +""" + LagrangianPrimalLoss(bm::BatchModel) + +Returns a function that computes the augmented lagrangian primal loss +from current dual predictions `dual_hat` for the batch model `bm` under parameters `Θ`. + +Arguments: + - `bm`: A `BatchModel` instance that contains the model and batch configuration. 
+""" function LagrangianPrimalLoss(bm::BatchModel) return (model, ps, st, data) -> begin - Θ, dual_hat, ρ = data + Θ, hpm, dual_hat = data + ρ = hpm.ρ num_s = size(Θ, 2) # Forward pass for prediction @@ -57,9 +78,155 @@ function LagrangianPrimalLoss(bm::BatchModel) return total_loss, st_new, ( total_loss=total_loss, + mean_violations=mean(V), + new_max_violation=maximum(V), + mean_objs=mean(objs), ) end end +mutable struct TrainingStepLoop + loss_fn::Function + stopping_criteria::Vector{Function} + hyperparameters::Dict{Symbol, Any} + parameter_update_fns::Vector{Function} + reconcile_state::Function + pre_hook::Function +end + +function _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) + # Forward pass for dual model + dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.state) + + return (dual_hat_k,) +end + +function _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) + # # Forward pass for primal model + X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.state) + gh = constraints!(bm, X̂, Θ) + + # Forward pass for dual model + dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.state) + + return (dual_hat, gh,) +end + +function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) + max_violation = maximum([s.new_max_violation for s in batch_states]) + mean_violations = mean([s.mean_violations for s in batch_states]) + mean_objs = mean([s.mean_objs for s in batch_states]) + mean_loss = mean([s.total_loss for s in batch_states]) + return (; + new_max_violation=max_violation, + mean_violations=mean_violations, + mean_objs=mean_objs, + total_loss=mean_loss, + ) +end + +function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple}) + dual_loss = mean([s.dual_loss for s in batch_states]) + return (dual_loss=dual_loss,) +end + +function _update_ALM_ρ!(hpm::Dict{Symbol, Any}, current_state::NamedTuple) + if current_state.new_max_violation > hpm.τ * hpm.max_violation + hpm.ρ = min(hpm.ρmax, hpm.ρ * hpm.α) + end + hpm.max_violation = current_state.new_max_violation + return +end + +function _default_primal_loop(bm::BatchModel) + return TrainingStepLoop( + LagrangianPrimalLoss(bm), + [(iter, current_state, hpm) -> iter >= 100 ? true : false], + Dict{Symbol, Any}( + :ρ => 1.0, + :ρmax => 1e6, + :τ => 0.8, + :α => 10.0, + max_violation => 0.0, + ), + [_update_ALM_ρ!], + _reconcile_alm_primal_state, + _pre_hook_primal + ) +end + +function _default_dual_loop() + return TrainingStepLoop( + LagrangianDualLoss(), + [(iter, current_state, hpm) -> iter >= 100 ? 
true : false], + Dict{Symbol, Any}( + :max_dual => 1e6, + ), + [], + _reconcile_alm_dual_state, + _pre_hook_dual + ) +end + +function L2OALM_epoch( + bm::BatchModel, + primal_model::Lux.Model, + train_state_primal::Lux.TrainingState, + dual_model::Lux.Model, + train_state_dual::Lux.TrainingState, + training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), + training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), + data +) + iter_primal = 1 + iter_dual = 1 + num_batches = length(data) + current_state_primal = (;) + current_state_dual = (;) + + # primal loop + while all(stopping_criterion(iter_primal, current_state_primal, training_step_loop_primal.hyperparameters) for stopping_criterion in training_step_loop_primal.stopping_criteria) + current_states_primal = Vector{NamedTuple}(undef, num_batches) + iter_batch = 1 + for (θ) in data + _, loss_val, stats, train_state_primal = Training.single_train_step!( + AutoZygote(), # AD backend + training_step_loop_primal.loss_fn, # Loss function + (θ, training_step_loop_primal.hyperparameters, training_step_loop_primal.pre_hook(θ, primal_model, train_state_primal, dual_model, train_state_dual)...), # Data + train_state_primal # Training state + ) + current_states_primal[iter_batch] = stats + iter_batch += 1 + end + current_state_primal = training_step_loop_primal.reconcile_state(current_states_primal) + iter_primal += 1 + end + for fn in training_step_loop_primal.parameter_update_fns + fn(training_step_loop_primal.hyperparameters, current_state_primal) + end + + # dual loop + while all(stopping_criterion(iter_dual, current_state_dual, training_step_loop_dual.hyperparameters) for stopping_criterion in training_step_loop_dual.stopping_criteria) + current_states_dual = Vector{NamedTuple}(undef, num_batches) + iter_batch = 1 + for (θ) in data + _, loss_val, stats, train_state_dual = Training.single_train_step!( + AutoZygote(), # AD backend + training_step_loop_dual.loss_fn, # Loss function + (θ, training_step_loop_dual.hyperparameters, training_step_loop_dual.pre_hook(θ, primal_model, train_state_primal, dual_model, train_state_dual)...), # Data + train_state_dual # Training state + ) + current_states_dual[iter_batch] = stats + iter_batch += 1 + end + current_state_dual = training_step_loop_dual.reconcile_state(current_states_dual) + iter_dual += 1 + end + for fn in training_step_loop_dual.parameter_update_fns + fn(training_step_loop_dual.hyperparameters, current_state_dual) + end + return +end + end From e36c2e8fbe4a55f50e82be50811045d3b9e383c0 Mon Sep 17 00:00:00 2001 From: Andrew Date: Mon, 21 Jul 2025 16:04:37 -0400 Subject: [PATCH 06/15] add dockstrings --- src/L2OALM.jl | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 0db1a2c..00f09eb 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -85,6 +85,18 @@ function LagrangianPrimalLoss(bm::BatchModel) end end +""" + TrainingStepLoop + +A structure to define a training step loop for the L2O-ALM algorithm. + +Fields: +- `loss_fn`: Function to compute the loss for the training step. +- `stopping_criteria`: Vector of functions that determine when to stop the training loop. +- `hyperparameters`: Dictionary of hyperparameters and hyper-states for the training step. +- `parameter_update_fns`: Vector of functions to update hyperparameters after each training step. +- `reconcile_state`: Function to reconcile the state after processing a batch of data. 
+""" mutable struct TrainingStepLoop loss_fn::Function stopping_criteria::Vector{Function} @@ -94,6 +106,12 @@ mutable struct TrainingStepLoop pre_hook::Function end +""" + _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) + +Default pre-hook function for the primal model in the L2O-ALM algorithm. +This function performs a forward pass through the dual model to obtain the dual predictions. +""" function _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) # Forward pass for dual model dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.state) @@ -101,6 +119,12 @@ function _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, trai return (dual_hat_k,) end +""" + _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) + +Default pre-hook function for the dual model in the L2O-ALM algorithm. +This function performs a forward pass through the primal model to obtain the predicted state and constraints. +""" function _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) # # Forward pass for primal model X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.state) @@ -112,6 +136,13 @@ function _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_ return (dual_hat, gh,) end +""" + _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) + +Default function that reconciles the state of the primal model after processing a batch of data. +This function computes the maximum violation, mean violations, mean objectives, and total loss +from the batch states. +""" function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) max_violation = maximum([s.new_max_violation for s in batch_states]) mean_violations = mean([s.mean_violations for s in batch_states]) @@ -125,11 +156,23 @@ function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) ) end +""" + _reconcile_alm_dual_state(batch_states::Vector{NamedTuple}) + +Default function that reconciles the state of the dual model after processing a batch of data. +This function computes the mean dual loss from the batch states. +""" function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple}) dual_loss = mean([s.dual_loss for s in batch_states]) return (dual_loss=dual_loss,) end +""" + _update_ALM_ρ!(hpm::Dict{Symbol, Any}, current_state::NamedTuple) + +Default function to update the hyperparameter ρ in the ALM algorithm. +This function increases ρ by a factor of α if the new maximum violation exceeds τ times the previous maximum violation. +""" function _update_ALM_ρ!(hpm::Dict{Symbol, Any}, current_state::NamedTuple) if current_state.new_max_violation > hpm.τ * hpm.max_violation hpm.ρ = min(hpm.ρmax, hpm.ρ * hpm.α) @@ -138,6 +181,11 @@ function _update_ALM_ρ!(hpm::Dict{Symbol, Any}, current_state::NamedTuple) return end +""" + _default_primal_loop(bm::BatchModel) + +Returns a default `TrainingStepLoop` for the primal model in the L2O-ALM algorithm. +""" function _default_primal_loop(bm::BatchModel) return TrainingStepLoop( LagrangianPrimalLoss(bm), @@ -155,6 +203,11 @@ function _default_primal_loop(bm::BatchModel) ) end +""" + _default_dual_loop() + +Returns a default `TrainingStepLoop` for the dual model in the L2O-ALM algorithm. 
+""" function _default_dual_loop() return TrainingStepLoop( LagrangianDualLoss(), @@ -168,6 +221,25 @@ function _default_dual_loop() ) end +""" + L2OALM_epoch(bm::BatchModel, primal_model::Lux.Model, train_state_primal::Lux.TrainingState, + dual_model::Lux.Model, train_state_dual::Lux.TrainingState, + training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), + training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), + data) + +Runs a single epoch of the L2O-ALM algorithm, training both primal and dual models. + +Arguments: +- `bm`: A `BatchModel` instance that contains the model and batch configuration. +- `primal_model`: The Lux model for the primal problem. +- `train_state_primal`: The training state for the primal model. +- `dual_model`: The Lux model for the dual problem. +- `train_state_dual`: The training state for the dual model. +- `training_step_loop_primal`: The training step loop for the primal model. +- `training_step_loop_dual`: The training step loop for the dual model. +- `data`: The training data, typically a collection of batches. +""" function L2OALM_epoch( bm::BatchModel, primal_model::Lux.Model, From 9444604e421c820233cc62a26e9db8d0ad3148b0 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 14:54:21 -0400 Subject: [PATCH 07/15] add training main loop --- src/L2OALM.jl | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 00f09eb..44601d7 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -300,5 +300,55 @@ function L2OALM_epoch( return end +""" + L2OALM_train(bm::BatchModel, primal_model::Lux.Model, dual_model::Lux.Model, + train_state_primal::Lux.TrainingState, train_state_dual::Lux.TrainingState, + training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), + training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), + stopping_criteria::Vector{Function}=[(iter, primal_model::Lux.Model, dual_model::Lux.Model, + train_state_primal::Lux.TrainingState, train_state_dual::Lux.TrainingState) -> iter >= 100 ? true : false], + data + ) + +Runs the L2O-ALM training algorithm until the stopping criteria are met. + +Arguments: +- `bm`: A `BatchModel` instance that contains the model and batch configuration. +- `primal_model`: The Lux model for the primal problem. +- `dual_model`: The Lux model for the dual problem. +- `train_state_primal`: The training state for the primal model. +- `train_state_dual`: The training state for the dual model. +- `training_step_loop_primal`: The training step loop for the primal model. +- `training_step_loop_dual`: The training step loop for the dual model. +- `stopping_criteria`: A vector of functions that determine when to stop the training loop. +- `data`: The training data, typically a collection of batches. +""" +function L2OALM_train( + bm::BatchModel, + primal_model::Lux.Model, + dual_model::Lux.Model, + train_state_primal::Lux.TrainingState, + train_state_dual::Lux.TrainingState, + training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), + training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), + stopping_criteria::Vector{Function}=[(iter, current_state, hpm) -> iter >= 100 ? 
true : false], + data +) + iter = 1 + while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual) for stopping_criterion in stopping_criteria) + L2OALM_epoch( + bm, + primal_model, + train_state_primal, + dual_model, + train_state_dual, + training_step_loop_primal, + training_step_loop_dual, + data + ) + iter += 1 + end + return +end end From 58b2fbf33193625602fdec03d75c3f159796e269 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 15:34:10 -0400 Subject: [PATCH 08/15] update tests --- src/L2OALM.jl | 17 +++++---- test/power.jl | 16 ++------ test/runtests.jl | 97 +++++++++++++++++++++++++++++------------------- 3 files changed, 72 insertions(+), 58 deletions(-) diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 44601d7..df6d114 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -10,7 +10,8 @@ using MLUtils using Optimisers using CUDA -# using OpenCL, pocl_jll, AcceleratedKernels +export LagrangianDualLoss, LagrangianPrimalLoss, TrainingStepLoop, + L2OALM_epoch!, L2OALM_train! """ LagrangianDualLoss(;max_dual=1e6) @@ -23,7 +24,7 @@ Target dual variables are clipped from zero to `max_dual`. Keywords: - `max_dual`: Maximum value for the target dual variables. """ -function LagrangianDualLoss(;max_dual=1e6) +function LagrangianDualLoss(n_bus::Int; max_dual=1e6) return (dual_model, ps_dual, st_dual, data) -> begin x, hpm, dual_hat_k, gh = data ρ = hpm.ρ @@ -240,14 +241,14 @@ Arguments: - `training_step_loop_dual`: The training step loop for the dual model. - `data`: The training data, typically a collection of batches. """ -function L2OALM_epoch( +function L2OALM_epoch!( bm::BatchModel, primal_model::Lux.Model, train_state_primal::Lux.TrainingState, dual_model::Lux.Model, train_state_dual::Lux.TrainingState, - training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), - training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), + training_step_loop_primal::TrainingStepLoop, + training_step_loop_dual::TrainingStepLoop, data ) iter_primal = 1 @@ -323,7 +324,7 @@ Arguments: - `stopping_criteria`: A vector of functions that determine when to stop the training loop. - `data`: The training data, typically a collection of batches. """ -function L2OALM_train( +function L2OALM_train!( bm::BatchModel, primal_model::Lux.Model, dual_model::Lux.Model, @@ -331,12 +332,12 @@ function L2OALM_train( train_state_dual::Lux.TrainingState, training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), - stopping_criteria::Vector{Function}=[(iter, current_state, hpm) -> iter >= 100 ? true : false], + stopping_criteria::Vector{Function}=[(iter, primal_model, dual_model, train_state_primal, train_state_dual) -> iter >= 100 ? 
true : false], data ) iter = 1 while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual) for stopping_criterion in stopping_criteria) - L2OALM_epoch( + L2OALM_epoch!( bm, primal_model, train_state_primal, diff --git a/test/power.jl b/test/power.jl index 545e707..0e021df 100644 --- a/test/power.jl +++ b/test/power.jl @@ -116,8 +116,7 @@ function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_i pg = variable(c, length(data.gen); lvar = data.pmin, uvar = data.pmax) qg = variable(c, length(data.gen); lvar = data.qmin, uvar = data.qmax) - @allowscalar pd = parameter(c, [b.pd for b in data.bus]) - @allowscalar qd = parameter(c, [b.qd for b in data.bus]) + @allowscalar load_multiplier = parameter(c, [1.0 for b in data.bus]) p = variable(c, length(data.arc); lvar = -data.rate_a, uvar = data.rate_a) q = variable(c, length(data.arc); lvar = -data.rate_a, uvar = data.rate_a) @@ -179,8 +178,8 @@ function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_i ) # Power balance at each bus ----------------------------------------------- - load_balance_p = constraint(c, pd[b.i] + b.gs * vm[b.i]^2 for b in data.bus) - load_balance_q = constraint(c, qd[b.i] - b.bs * vm[b.i]^2 for b in data.bus) + load_balance_p = constraint(c, b.pd * load_multiplier[b.i] + b.gs * vm[b.i]^2 for b in data.bus) + load_balance_q = constraint(c, b.qd * load_multiplier[b.i] - b.bs * vm[b.i]^2 for b in data.bus) # Map arc & generator variables into the bus balance equations constraint!(c, load_balance_p, a.bus => p[a.i] for a in data.arc) @@ -188,12 +187,5 @@ function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_i constraint!(c, load_balance_p, g.bus => -pg[g.i] for g in data.gen) constraint!(c, load_balance_q, g.bus => -qg[g.i] for g in data.gen) - return ExaModel(c; prod = prod) + return ExaModel(c; prod = prod), length(data.bus), length(data.gen), length(data.arc) end - -function create_power_models(backend = OpenCLBackend(), T=Float64) - models = ExaModel[] - push!(models, create_ac_power_model("pglib_opf_case14_ieee.m"; backend = backend)) - names = ["AC-OPF – IEEE-14"] - return models, names -end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 816669b..2089435 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using L2OALM using Test +using Lux + @testset "L2OALM.jl" begin function feed_forward_builder( @@ -29,68 +31,87 @@ using Test return Chain(dense_layers...) 
end - function test_penalty_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = gpu_device(), + function test_alm_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = gpu_device(), backend=CPU(), batch_size = 32, dataset_size = 3200, rng = Random.default_rng(), T = Float64, ) - model = create_parametric_ac_power_model(filename; backend = backend, T=T) - bm = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) - bm_all = BNK.BatchModel(model, dataset_size, config=BNK.BatchModelConfig(:full)) - - function PenaltyLoss(model, ps, st, Θ) - X̂ , st_new = model(Θ, ps, st) - - obj = BNK.objective!(bm, X̂, Θ) - Vc, Vb = BNK.all_violations!(bm, X̂, Θ) - - return sum(obj) + 1000 * sum(Vc) + 1000 * sum(Vb), st_new, (;obj=sum(obj), Vc=sum(Vc), Vb=sum(Vb)) - end + model, nbus, ngen, blines = create_parametric_ac_power_model(filename; backend = backend, T=T) + bm_train = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) + bm_test = BNK.BatchModel(model, dataset_size, config=BNK.BatchModelConfig(:full)) nvar = model.meta.nvar ncon = model.meta.ncon nθ = length(model.θ) Θ_train = randn(T, nθ, dataset_size) |> dev_gpu + Θ_test = randn(T, nθ, dataset_size) |> dev_gpu - lux_model = feed_forward_builder(nθ, nvar, [320, 320]) - - ps_model, st_model = Lux.setup(rng, lux_model) - ps_model = ps_model |> dev_gpu - st_model = st_model |> dev_gpu + primal_model = feed_forward_builder(nθ, nvar, [320, 320]) + ps_primal, st_primal = Lux.setup(rng, primal_model) + ps_primal = ps_primal |> dev_gpu + st_primal = st_primal |> dev_gpu + + dual_model = feed_forward_builder(nθ, ncon, [320, 320]) + ps_dual, st_dual = Lux.setup(rng, dual_model) + ps_dual = ps_dual |> dev_gpu + st_dual = st_dual |> dev_gpu - X̂ , _ = lux_model(Θ_train, ps_model, st_model) + X̂ , _ = primal_model(Θ_test, ps_primal, st_primal) - y = BNK.objective!(bm_all, X̂, Θ_train) + y = BNK.objective!(bm_test, X̂, Θ_test) @test length(y) == dataset_size - Vc, Vb = BNK.all_violations!(bm_all, X̂, Θ_train) + Vc, Vb = BNK.all_violations!(bm_test, X̂, Θ_test) @test size(Vc) == (ncon, dataset_size) @test size(Vb) == (nvar, dataset_size) - lagrangian_prev = sum(y) + 1000 * sum(Vc) + 1000 * sum(Vb) + # lagrangian_prev = sum(y) + 1000 * sum(Vc) + 1000 * sum(Vb) - train_state = Training.TrainState(lux_model, ps_model, st_model, Optimisers.Adam(1e-5)) + train_state_primal = Training.TrainState(primal_model, ps_primal, st_primal, Optimisers.Adam(1e-5)) + train_state_dual = Training.TrainState(dual_model, ps_dual, st_dual, Optimisers.Adam(1e-5)) data = DataLoader((Θ_train); batchsize=batch_size, shuffle=true) .|> dev_gpu - for (Θ) in data - _, loss_val, stats, train_state = Training.single_train_step!( - AutoZygote(), # AD backend - PenaltyLoss, - (Θ), # data - train_state + + function validation_testset( + iter, primal_model, dual_model, train_state_primal, train_state_dual, + ) + X̂_test , _ = primal_model(Θ_test, train_state_primal.parameters, train_state_primal.state) + objs_test = BNK.objective!(bm_test, X̂_test, Θ_test) + Vc_test, Vb_test = BNK.all_violations!(bm_test, X̂_test, Θ_test) + gh_test = BNK.constraints!(bm_test, X̂_test, Θ_test) + λ_test, _ = dual_model(Θ_test, train_state_dual.parameters, train_state_dual.state) + # Separate bound and equality constraints + gh_bound = gh[1:end-n_bus*2,:] + gh_equal = gh[end-n_bus*2+1:end,:] + dual_hat_bound = dual_hat_k[1:end-n_bus*2,:] + dual_hat_equal = dual_hat_k[end-n_bus*2+1:end,:] + + # Target for dual variables + dual_target = vcat( + min.(max.(dual_hat_bound + ρ .* 
gh_bound, 0), max_dual), + min.(dual_hat_equal + ρ .* gh_equal, max_dual) ) + + loss = mean((dual_hat .- dual_target).^2) + + @info "Validation Testset: Iteration $iter" mean(objs_test) mean(Vc_test) mean(Vb_test) + return iter >= 100 ? true : false end - - X̂ , st_model = lux_model(Θ_train, ps_model, st_model) - - y = BNK.objective!(bm_all, X̂, Θ_train) - Vc, Vb = BNK.all_violations!(bm_all, X̂, Θ_train) - lagrangian_new = sum(y) + 1000 * sum(Vc) + 1000 * sum(Vb) - - @test lagrangian_new < lagrangian_prev + + L2OALM_train!( + bm_train, + primal_model, + dual_model, + train_state_primal, + train_state_dual, + training_step_loop_primal, + training_step_loop_dual, + stopping_criteria::Vector{Function}=[(iter, current_state, hpm) -> iter >= 100 ? true : false], + data + ) end @testset "Penalty Training" begin @@ -100,6 +121,6 @@ using Test CPU(), cpu_device() end - test_penalty_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = dev, backend=backend, T=Float32) + test_alm_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = dev, backend=backend, T=Float32) end end From b84c8de00e8535c7560f8d1a13e2304eb78529e1 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 16:45:10 -0400 Subject: [PATCH 09/15] update tests --- Project.toml | 4 +++- src/L2OALM.jl | 37 +++++++++++++++++++------------------ test/runtests.jl | 23 ++++++++++++++--------- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 4a83842..c1a6422 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,9 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" julia = "1.6.7" [extras] +PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +PGLib = "07a8691f-3d11-4330-951b-3c50f98338be" [targets] -test = ["Test"] +test = ["Test", "PowerModels", "PGLib"] diff --git a/src/L2OALM.jl b/src/L2OALM.jl index df6d114..df29daf 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -24,7 +24,7 @@ Target dual variables are clipped from zero to `max_dual`. Keywords: - `max_dual`: Maximum value for the target dual variables. """ -function LagrangianDualLoss(n_bus::Int; max_dual=1e6) +function LagrangianDualLoss(num_equal::Int; max_dual=1e6) return (dual_model, ps_dual, st_dual, data) -> begin x, hpm, dual_hat_k, gh = data ρ = hpm.ρ @@ -32,10 +32,10 @@ function LagrangianDualLoss(n_bus::Int; max_dual=1e6) dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual) # Separate bound and equality constraints - gh_bound = gh[1:end-n_bus*2,:] - gh_equal = gh[end-n_bus*2+1:end,:] - dual_hat_bound = dual_hat_k[1:end-n_bus*2,:] - dual_hat_equal = dual_hat_k[end-n_bus*2+1:end,:] + gh_bound = gh[1:end-num_equal,:] + gh_equal = gh[end-num_equal+1:end,:] + dual_hat_bound = dual_hat_k[1:end-num_equal,:] + dual_hat_equal = dual_hat_k[end-num_equal+1:end,:] # Target for dual variables dual_target = vcat( @@ -174,11 +174,12 @@ end Default function to update the hyperparameter ρ in the ALM algorithm. This function increases ρ by a factor of α if the new maximum violation exceeds τ times the previous maximum violation. 
""" -function _update_ALM_ρ!(hpm::Dict{Symbol, Any}, current_state::NamedTuple) - if current_state.new_max_violation > hpm.τ * hpm.max_violation - hpm.ρ = min(hpm.ρmax, hpm.ρ * hpm.α) +function _update_ALM_ρ!(hpm_primal::Dict{Symbol, Any}, hpm_dual::Dict{Symbol, Any}, current_state::NamedTuple) + if current_state.new_max_violation > hpm_primal.τ * hpm_primal.max_violation + hpm_primal.ρ = min(hpm_primal.ρmax, hpm_primal.ρ * hpm_primal.α) + hpm_dual.ρ = hpm_primal.ρ # Ensure dual model uses the same ρ end - hpm.max_violation = current_state.new_max_violation + hpm_primal.max_violation = current_state.new_max_violation return end @@ -209,12 +210,13 @@ end Returns a default `TrainingStepLoop` for the dual model in the L2O-ALM algorithm. """ -function _default_dual_loop() +function _default_dual_loop(num_equal::Int) return TrainingStepLoop( - LagrangianDualLoss(), + LagrangianDualLoss(num_equal), [(iter, current_state, hpm) -> iter >= 100 ? true : false], Dict{Symbol, Any}( :max_dual => 1e6, + :ρ => 1.0, ), [], _reconcile_alm_dual_state, @@ -242,7 +244,6 @@ Arguments: - `data`: The training data, typically a collection of batches. """ function L2OALM_epoch!( - bm::BatchModel, primal_model::Lux.Model, train_state_primal::Lux.TrainingState, dual_model::Lux.Model, @@ -275,7 +276,7 @@ function L2OALM_epoch!( iter_primal += 1 end for fn in training_step_loop_primal.parameter_update_fns - fn(training_step_loop_primal.hyperparameters, current_state_primal) + fn(training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters, current_state_primal) end # dual loop @@ -296,7 +297,7 @@ function L2OALM_epoch!( iter_dual += 1 end for fn in training_step_loop_dual.parameter_update_fns - fn(training_step_loop_dual.hyperparameters, current_state_dual) + fn(training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters, current_state_dual) end return end @@ -326,19 +327,19 @@ Arguments: """ function L2OALM_train!( bm::BatchModel, + num_equal::Int, primal_model::Lux.Model, dual_model::Lux.Model, train_state_primal::Lux.TrainingState, train_state_dual::Lux.TrainingState, training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), - training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), - stopping_criteria::Vector{Function}=[(iter, primal_model, dual_model, train_state_primal, train_state_dual) -> iter >= 100 ? true : false], + training_step_loop_dual::TrainingStepLoop=_default_dual_loop(num_equal), + stopping_criteria::Vector{Function}=[(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> iter >= 100 ? 
true : false], data ) iter = 1 - while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual) for stopping_criterion in stopping_criteria) + while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual, training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters) for stopping_criterion in stopping_criteria) L2OALM_epoch!( - bm, primal_model, train_state_primal, dual_model, diff --git a/test/runtests.jl b/test/runtests.jl index 2089435..96b428a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,9 @@ using L2OALM using Test using Lux +using PowerModels +PowerModels.silence() +using PGLib @testset "L2OALM.jl" begin function feed_forward_builder( @@ -45,6 +48,7 @@ using Lux nvar = model.meta.nvar ncon = model.meta.ncon nθ = length(model.θ) + num_equal = nbus * 2 Θ_train = randn(T, nθ, dataset_size) |> dev_gpu Θ_test = randn(T, nθ, dataset_size) |> dev_gpu @@ -76,18 +80,19 @@ using Lux data = DataLoader((Θ_train); batchsize=batch_size, shuffle=true) .|> dev_gpu function validation_testset( - iter, primal_model, dual_model, train_state_primal, train_state_dual, + iter, primal_model, dual_model, train_state_primal, train_state_dual, hpm_primal, hpm_dual; max_dual=1e6 ) + ρ = hpm_primal.ρ X̂_test , _ = primal_model(Θ_test, train_state_primal.parameters, train_state_primal.state) objs_test = BNK.objective!(bm_test, X̂_test, Θ_test) Vc_test, Vb_test = BNK.all_violations!(bm_test, X̂_test, Θ_test) gh_test = BNK.constraints!(bm_test, X̂_test, Θ_test) - λ_test, _ = dual_model(Θ_test, train_state_dual.parameters, train_state_dual.state) + dual_hat, _ = dual_model(Θ_test, train_state_dual.parameters, train_state_dual.state) # Separate bound and equality constraints - gh_bound = gh[1:end-n_bus*2,:] - gh_equal = gh[end-n_bus*2+1:end,:] - dual_hat_bound = dual_hat_k[1:end-n_bus*2,:] - dual_hat_equal = dual_hat_k[end-n_bus*2+1:end,:] + gh_bound = gh_test[1:end-num_equal,:] + gh_equal = gh_test[end-num_equal+1:end,:] + dual_hat_bound = dual_hat[1:end-num_equal,:] + dual_hat_equal = dual_hat[end-num_equal+1:end,:] # Target for dual variables dual_target = vcat( @@ -95,9 +100,9 @@ using Lux min.(dual_hat_equal + ρ .* gh_equal, max_dual) ) - loss = mean((dual_hat .- dual_target).^2) + dual_loss = mean((dual_hat .- dual_target).^2) - @info "Validation Testset: Iteration $iter" mean(objs_test) mean(Vc_test) mean(Vb_test) + @info "Validation Testset: Iteration $iter" mean(objs_test) mean(Vc_test) mean(Vb_test) dual_loss return iter >= 100 ? true : false end @@ -109,7 +114,7 @@ using Lux train_state_dual, training_step_loop_primal, training_step_loop_dual, - stopping_criteria::Vector{Function}=[(iter, current_state, hpm) -> iter >= 100 ? 
true : false], + [validation_testset], data ) end From e3cd8a79fce006d10b5287ba5cea41e9485deb0a Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 17:16:15 -0400 Subject: [PATCH 10/15] update deps --- Project.toml | 8 +++++--- src/L2OALM.jl | 36 ++++++++++++++++++------------------ test/runtests.jl | 10 +++++++++- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index c1a6422..bf56cfd 100644 --- a/Project.toml +++ b/Project.toml @@ -9,16 +9,18 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" [compat] julia = "1.6.7" [extras] +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +PGLib = "07a8691f-3d11-4330-951b-3c50f98338be" PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -PGLib = "07a8691f-3d11-4330-951b-3c50f98338be" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" [targets] -test = ["Test", "PowerModels", "PGLib"] +test = ["Test", "PowerModels", "PGLib", "Random", "MLUtils", "KernelAbstractions"] diff --git a/src/L2OALM.jl b/src/L2OALM.jl index df29daf..c9ed03a 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -6,8 +6,6 @@ using ExaModels using Lux using LuxCUDA using Lux.Training -using MLUtils -using Optimisers using CUDA export LagrangianDualLoss, LagrangianPrimalLoss, TrainingStepLoop, @@ -225,8 +223,8 @@ function _default_dual_loop(num_equal::Int) end """ - L2OALM_epoch(bm::BatchModel, primal_model::Lux.Model, train_state_primal::Lux.TrainingState, - dual_model::Lux.Model, train_state_dual::Lux.TrainingState, + L2OALM_epoch(bm::BatchModel, primal_model::Lux.Chain, train_state_primal::Lux.Training.TrainState, + dual_model::Lux.Chain, train_state_dual::Lux.Training.TrainState, training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), data) @@ -244,10 +242,10 @@ Arguments: - `data`: The training data, typically a collection of batches. """ function L2OALM_epoch!( - primal_model::Lux.Model, - train_state_primal::Lux.TrainingState, - dual_model::Lux.Model, - train_state_dual::Lux.TrainingState, + primal_model::Lux.Chain, + train_state_primal::Lux.Training.TrainState, + dual_model::Lux.Chain, + train_state_dual::Lux.Training.TrainState, training_step_loop_primal::TrainingStepLoop, training_step_loop_dual::TrainingStepLoop, data @@ -303,12 +301,13 @@ function L2OALM_epoch!( end """ - L2OALM_train(bm::BatchModel, primal_model::Lux.Model, dual_model::Lux.Model, - train_state_primal::Lux.TrainingState, train_state_dual::Lux.TrainingState, + L2OALM_train(bm::BatchModel, num_equal::Int, + primal_model::Lux.Chain, dual_model::Lux.Chain, + train_state_primal::Lux.Training.TrainState, train_state_dual::Lux.Training.TrainState, training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), - stopping_criteria::Vector{Function}=[(iter, primal_model::Lux.Model, dual_model::Lux.Model, - train_state_primal::Lux.TrainingState, train_state_dual::Lux.TrainingState) -> iter >= 100 ? 
true : false], + stopping_criteria::Vector{Function}=[(iter, primal_model::Lux.Chain, dual_model::Lux.Chain, + train_state_primal::Lux.Training.TrainState, train_state_dual::Lux.Training.TrainState) -> iter >= 100 ? true : false], data ) @@ -316,26 +315,27 @@ Runs the L2O-ALM training algorithm until the stopping criteria are met. Arguments: - `bm`: A `BatchModel` instance that contains the model and batch configuration. +- `num_equal`: The number of equality constraints in the problem, used for dual loss calculation. - `primal_model`: The Lux model for the primal problem. - `dual_model`: The Lux model for the dual problem. - `train_state_primal`: The training state for the primal model. - `train_state_dual`: The training state for the dual model. +- `data`: The training data, typically a collection of batches. - `training_step_loop_primal`: The training step loop for the primal model. - `training_step_loop_dual`: The training step loop for the dual model. - `stopping_criteria`: A vector of functions that determine when to stop the training loop. -- `data`: The training data, typically a collection of batches. """ function L2OALM_train!( bm::BatchModel, num_equal::Int, - primal_model::Lux.Model, - dual_model::Lux.Model, - train_state_primal::Lux.TrainingState, - train_state_dual::Lux.TrainingState, + primal_model::Lux.Chain, + dual_model::Lux.Chain, + train_state_primal::Lux.Training.TrainState, + train_state_dual::Lux.Training.TrainState, + data, training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), training_step_loop_dual::TrainingStepLoop=_default_dual_loop(num_equal), stopping_criteria::Vector{Function}=[(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> iter >= 100 ? true : false], - data ) iter = 1 while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual, training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters) for stopping_criterion in stopping_criteria) diff --git a/test/runtests.jl b/test/runtests.jl index 96b428a..9c0e18c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,19 @@ using L2OALM using Test using Lux +using Optimisers +using MLUtils +using KernelAbstractions +using ExaModels +using BatchNLPKernels +using CUDA using PowerModels PowerModels.silence() using PGLib +using Random + @testset "L2OALM.jl" begin function feed_forward_builder( num_p::Integer, @@ -112,10 +120,10 @@ using PGLib dual_model, train_state_primal, train_state_dual, + data, training_step_loop_primal, training_step_loop_dual, [validation_testset], - data ) end From a0f3354dd12327ee6f0a00bb0f8a7d980b46a302 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 17:40:11 -0400 Subject: [PATCH 11/15] fix typos --- Project.toml | 5 +++-- src/L2OALM.jl | 22 +++++++++++----------- test/runtests.jl | 17 +++++++++++------ 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index bf56cfd..97d8fe1 100644 --- a/Project.toml +++ b/Project.toml @@ -15,12 +15,13 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" julia = "1.6.7" [extras] +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" PGLib = "07a8691f-3d11-4330-951b-3c50f98338be" PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" 
[targets] -test = ["Test", "PowerModels", "PGLib", "Random", "MLUtils", "KernelAbstractions"] +test = ["Test", "PowerModels", "PGLib", "Random", "MLUtils", "KernelAbstractions", "GPUArraysCore"] diff --git a/src/L2OALM.jl b/src/L2OALM.jl index c9ed03a..5a64cc1 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -25,7 +25,7 @@ Keywords: function LagrangianDualLoss(num_equal::Int; max_dual=1e6) return (dual_model, ps_dual, st_dual, data) -> begin x, hpm, dual_hat_k, gh = data - ρ = hpm.ρ + ρ = hpm[:ρ] # Get current dual predictions dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual) @@ -58,7 +58,7 @@ Arguments: function LagrangianPrimalLoss(bm::BatchModel) return (model, ps, st, data) -> begin Θ, hpm, dual_hat = data - ρ = hpm.ρ + ρ = hpm[:ρ] num_s = size(Θ, 2) # Forward pass for prediction @@ -113,7 +113,7 @@ This function performs a forward pass through the dual model to obtain the dual """ function _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) # Forward pass for dual model - dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.state) + dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states) return (dual_hat_k,) end @@ -126,11 +126,11 @@ This function performs a forward pass through the primal model to obtain the pre """ function _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) # # Forward pass for primal model - X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.state) + X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.states) gh = constraints!(bm, X̂, Θ) # Forward pass for dual model - dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.state) + dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states) return (dual_hat, gh,) end @@ -174,8 +174,8 @@ This function increases ρ by a factor of α if the new maximum violation exceed """ function _update_ALM_ρ!(hpm_primal::Dict{Symbol, Any}, hpm_dual::Dict{Symbol, Any}, current_state::NamedTuple) if current_state.new_max_violation > hpm_primal.τ * hpm_primal.max_violation - hpm_primal.ρ = min(hpm_primal.ρmax, hpm_primal.ρ * hpm_primal.α) - hpm_dual.ρ = hpm_primal.ρ # Ensure dual model uses the same ρ + hpm_primal[:ρ] = min(hpm_primal[:ρmax], hpm_primal[:ρ] * hpm_primal[:α]) + hpm_dual[:ρ] = hpm_primal[:ρ] # Ensure dual model uses the same ρ end hpm_primal.max_violation = current_state.new_max_violation return @@ -195,7 +195,7 @@ function _default_primal_loop(bm::BatchModel) :ρmax => 1e6, :τ => 0.8, :α => 10.0, - max_violation => 0.0, + :max_violation => 0.0, ), [_update_ALM_ρ!], _reconcile_alm_primal_state, @@ -332,11 +332,11 @@ function L2OALM_train!( dual_model::Lux.Chain, train_state_primal::Lux.Training.TrainState, train_state_dual::Lux.Training.TrainState, - data, + data; training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), training_step_loop_dual::TrainingStepLoop=_default_dual_loop(num_equal), - stopping_criteria::Vector{Function}=[(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> iter >= 100 ? true : false], -) + stopping_criteria::Vector{F}=[(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> iter >= 100 ? 
true : false], +) where F<:Function iter = 1 while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual, training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters) for stopping_criterion in stopping_criteria) L2OALM_epoch!( diff --git a/test/runtests.jl b/test/runtests.jl index 9c0e18c..0e3e802 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,12 @@ using PGLib using Random +import GPUArraysCore: @allowscalar + +const BNK = BatchNLPKernels + +include("power.jl") + @testset "L2OALM.jl" begin function feed_forward_builder( num_p::Integer, @@ -90,12 +96,12 @@ using Random function validation_testset( iter, primal_model, dual_model, train_state_primal, train_state_dual, hpm_primal, hpm_dual; max_dual=1e6 ) - ρ = hpm_primal.ρ - X̂_test , _ = primal_model(Θ_test, train_state_primal.parameters, train_state_primal.state) + ρ = hpm_primal[:ρ] + X̂_test , _ = primal_model(Θ_test, train_state_primal.parameters, train_state_primal.states) objs_test = BNK.objective!(bm_test, X̂_test, Θ_test) Vc_test, Vb_test = BNK.all_violations!(bm_test, X̂_test, Θ_test) gh_test = BNK.constraints!(bm_test, X̂_test, Θ_test) - dual_hat, _ = dual_model(Θ_test, train_state_dual.parameters, train_state_dual.state) + dual_hat, _ = dual_model(Θ_test, train_state_dual.parameters, train_state_dual.states) # Separate bound and equality constraints gh_bound = gh_test[1:end-num_equal,:] gh_equal = gh_test[end-num_equal+1:end,:] @@ -116,14 +122,13 @@ using Random L2OALM_train!( bm_train, + num_equal, primal_model, dual_model, train_state_primal, train_state_dual, data, - training_step_loop_primal, - training_step_loop_dual, - [validation_testset], + stopping_criteria=[validation_testset], ) end From 1b34f0f5590ca461de439740dc63bcc0b605a2d6 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 17:42:21 -0400 Subject: [PATCH 12/15] running CPU --- Project.toml | 3 ++- src/L2OALM.jl | 1 + test/runtests.jl | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 97d8fe1..d05df0f 100644 --- a/Project.toml +++ b/Project.toml @@ -10,18 +10,19 @@ ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] julia = "1.6.7" [extras] +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" PGLib = "07a8691f-3d11-4330-951b-3c50f98338be" PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" [targets] test = ["Test", "PowerModels", "PGLib", "Random", "MLUtils", "KernelAbstractions", "GPUArraysCore"] diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 5a64cc1..ee2c249 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -7,6 +7,7 @@ using Lux using LuxCUDA using Lux.Training using CUDA +using Statistics export LagrangianDualLoss, LagrangianPrimalLoss, TrainingStepLoop, L2OALM_epoch!, L2OALM_train! 
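A minimal call sketch for the exported entry point, mirroring its use in test/runtests.jl (names such as `bm_train`, `num_equal`, the Lux models, their train states and the `DataLoader` `data` are assumed to be set up as in the tests; the anonymous stopping criterion below is illustrative and takes the same seven arguments the training loop passes to each criterion):

    L2OALM_train!(
        bm_train, num_equal,
        primal_model, dual_model,
        train_state_primal, train_state_dual,
        data,
        stopping_criteria = [
            # same shape as the default criterion, with a smaller iteration count
            (iter, pm, dm, ts_p, ts_d, hpm_p, hpm_d) -> iter >= 50,
        ],
    )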
diff --git a/test/runtests.jl b/test/runtests.jl index 0e3e802..814ea0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ PowerModels.silence() using PGLib using Random +using Statistics import GPUArraysCore: @allowscalar From 6ad1d317962d8aa2c1dd77a2eb663b82f8a3a818 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 17:43:54 -0400 Subject: [PATCH 13/15] format --- docs/make.jl | 25 +++---- src/L2OALM.jl | 183 ++++++++++++++++++++++++++++++++--------------- test/power.jl | 163 +++++++++++++++++++++++------------------ test/runtests.jl | 102 ++++++++++++++++---------- 4 files changed, 291 insertions(+), 182 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index de4ea88..99f1798 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,23 +1,18 @@ using L2OALM using Documenter -DocMeta.setdocmeta!(L2OALM, :DocTestSetup, :(using L2OALM); recursive=true) +DocMeta.setdocmeta!(L2OALM, :DocTestSetup, :(using L2OALM); recursive = true) makedocs(; - modules=[L2OALM], - authors="Andrew and contributors", - sitename="L2OALM.jl", - format=Documenter.HTML(; - canonical="https://LearningToOptimize.github.io/L2OALM.jl", - edit_link="main", - assets=String[], + modules = [L2OALM], + authors = "Andrew and contributors", + sitename = "L2OALM.jl", + format = Documenter.HTML(; + canonical = "https://LearningToOptimize.github.io/L2OALM.jl", + edit_link = "main", + assets = String[], ), - pages=[ - "Home" => "index.md", - ], + pages = ["Home" => "index.md"], ) -deploydocs(; - repo="github.com/LearningToOptimize/L2OALM.jl", - devbranch="main", -) +deploydocs(; repo = "github.com/LearningToOptimize/L2OALM.jl", devbranch = "main") diff --git a/src/L2OALM.jl b/src/L2OALM.jl index ee2c249..4e1e707 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -9,8 +9,8 @@ using Lux.Training using CUDA using Statistics -export LagrangianDualLoss, LagrangianPrimalLoss, TrainingStepLoop, - L2OALM_epoch!, L2OALM_train! +export LagrangianDualLoss, + LagrangianPrimalLoss, TrainingStepLoop, L2OALM_epoch!, L2OALM_train! """ LagrangianDualLoss(;max_dual=1e6) @@ -23,27 +23,27 @@ Target dual variables are clipped from zero to `max_dual`. Keywords: - `max_dual`: Maximum value for the target dual variables. """ -function LagrangianDualLoss(num_equal::Int; max_dual=1e6) +function LagrangianDualLoss(num_equal::Int; max_dual = 1e6) return (dual_model, ps_dual, st_dual, data) -> begin x, hpm, dual_hat_k, gh = data ρ = hpm[:ρ] # Get current dual predictions dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual) - + # Separate bound and equality constraints - gh_bound = gh[1:end-num_equal,:] - gh_equal = gh[end-num_equal+1:end,:] - dual_hat_bound = dual_hat_k[1:end-num_equal,:] - dual_hat_equal = dual_hat_k[end-num_equal+1:end,:] - + gh_bound = gh[1:end-num_equal, :] + gh_equal = gh[end-num_equal+1:end, :] + dual_hat_bound = dual_hat_k[1:end-num_equal, :] + dual_hat_equal = dual_hat_k[end-num_equal+1:end, :] + # Target for dual variables dual_target = vcat( min.(max.(dual_hat_bound + ρ .* gh_bound, 0), max_dual), - min.(dual_hat_equal + ρ .* gh_equal, max_dual) + min.(dual_hat_equal + ρ .* gh_equal, max_dual), ) - - loss = mean((dual_hat .- dual_target).^2) - return loss, st_dual_new, (dual_loss=loss,) + + loss = mean((dual_hat .- dual_target) .^ 2) + return loss, st_dual_new, (dual_loss = loss,) end end @@ -56,7 +56,7 @@ from current dual predictions `dual_hat` for the batch model `bm` under paramete Arguments: - `bm`: A `BatchModel` instance that contains the model and batch configuration. 
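For reference, the closure built below evaluates the batched augmented-Lagrangian value

    loss(Θ) = mean(objs) + sum(abs.(dual_hat .* V)) / B + (ρ / 2) * sum(V .^ 2) / B

where V stacks the bound violations Vb and constraint violations Vc at the predicted primal point, B is the batch size, and dual_hat is supplied by the dual network and treated as a constant when differentiating with respect to the primal parameters.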
""" -function LagrangianPrimalLoss(bm::BatchModel) +function LagrangianPrimalLoss(bm::BatchModel) return (model, ps, st, data) -> begin Θ, hpm, dual_hat = data ρ = hpm[:ρ] @@ -64,23 +64,23 @@ function LagrangianPrimalLoss(bm::BatchModel) # Forward pass for prediction X̂, st_new = model(Θ, ps, st) - + # Calculate violations and objectives objs = BNK.objective!(bm, X̂, Θ) # gh = constraints!(bm, X̂, Θ) Vc, Vb = BNK.all_violations!(bm, X̂, Θ) V = vcat(Vb, Vc) total_loss = ( - sum(abs.(dual_hat .* V)) / num_s + - ρ / 2 * sum((V).^2) / num_s + - mean(objs) + sum(abs.(dual_hat .* V)) / num_s + ρ / 2 * sum((V) .^ 2) / num_s + mean(objs) ) - return total_loss, st_new, ( - total_loss=total_loss, - mean_violations=mean(V), - new_max_violation=maximum(V), - mean_objs=mean(objs), + return total_loss, + st_new, + ( + total_loss = total_loss, + mean_violations = mean(V), + new_max_violation = maximum(V), + mean_objs = mean(objs), ) end end @@ -100,7 +100,7 @@ Fields: mutable struct TrainingStepLoop loss_fn::Function stopping_criteria::Vector{Function} - hyperparameters::Dict{Symbol, Any} + hyperparameters::Dict{Symbol,Any} parameter_update_fns::Vector{Function} reconcile_state::Function pre_hook::Function @@ -112,7 +112,14 @@ end Default pre-hook function for the primal model in the L2O-ALM algorithm. This function performs a forward pass through the dual model to obtain the dual predictions. """ -function _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) +function _pre_hook_primal( + θ, + primal_model, + train_state_primal, + dual_model, + train_state_dual, + bm, +) # Forward pass for dual model dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states) @@ -125,15 +132,22 @@ end Default pre-hook function for the dual model in the L2O-ALM algorithm. This function performs a forward pass through the primal model to obtain the predicted state and constraints. """ -function _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) +function _pre_hook_dual( + θ, + primal_model, + train_state_primal, + dual_model, + train_state_dual, + bm, +) # # Forward pass for primal model X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.states) gh = constraints!(bm, X̂, Θ) - + # Forward pass for dual model dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states) - return (dual_hat, gh,) + return (dual_hat, gh) end """ @@ -149,10 +163,10 @@ function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) mean_objs = mean([s.mean_objs for s in batch_states]) mean_loss = mean([s.total_loss for s in batch_states]) return (; - new_max_violation=max_violation, - mean_violations=mean_violations, - mean_objs=mean_objs, - total_loss=mean_loss, + new_max_violation = max_violation, + mean_violations = mean_violations, + mean_objs = mean_objs, + total_loss = mean_loss, ) end @@ -164,7 +178,7 @@ This function computes the mean dual loss from the batch states. """ function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple}) dual_loss = mean([s.dual_loss for s in batch_states]) - return (dual_loss=dual_loss,) + return (dual_loss = dual_loss,) end """ @@ -173,7 +187,11 @@ end Default function to update the hyperparameter ρ in the ALM algorithm. This function increases ρ by a factor of α if the new maximum violation exceeds τ times the previous maximum violation. 
""" -function _update_ALM_ρ!(hpm_primal::Dict{Symbol, Any}, hpm_dual::Dict{Symbol, Any}, current_state::NamedTuple) +function _update_ALM_ρ!( + hpm_primal::Dict{Symbol,Any}, + hpm_dual::Dict{Symbol,Any}, + current_state::NamedTuple, +) if current_state.new_max_violation > hpm_primal.τ * hpm_primal.max_violation hpm_primal[:ρ] = min(hpm_primal[:ρmax], hpm_primal[:ρ] * hpm_primal[:α]) hpm_dual[:ρ] = hpm_primal[:ρ] # Ensure dual model uses the same ρ @@ -191,7 +209,7 @@ function _default_primal_loop(bm::BatchModel) return TrainingStepLoop( LagrangianPrimalLoss(bm), [(iter, current_state, hpm) -> iter >= 100 ? true : false], - Dict{Symbol, Any}( + Dict{Symbol,Any}( :ρ => 1.0, :ρmax => 1e6, :τ => 0.8, @@ -200,7 +218,7 @@ function _default_primal_loop(bm::BatchModel) ), [_update_ALM_ρ!], _reconcile_alm_primal_state, - _pre_hook_primal + _pre_hook_primal, ) end @@ -213,13 +231,10 @@ function _default_dual_loop(num_equal::Int) return TrainingStepLoop( LagrangianDualLoss(num_equal), [(iter, current_state, hpm) -> iter >= 100 ? true : false], - Dict{Symbol, Any}( - :max_dual => 1e6, - :ρ => 1.0, - ), + Dict{Symbol,Any}(:max_dual => 1e6, :ρ => 1.0), [], _reconcile_alm_dual_state, - _pre_hook_dual + _pre_hook_dual, ) end @@ -249,7 +264,7 @@ function L2OALM_epoch!( train_state_dual::Lux.Training.TrainState, training_step_loop_primal::TrainingStepLoop, training_step_loop_dual::TrainingStepLoop, - data + data, ) iter_primal = 1 iter_dual = 1 @@ -258,36 +273,73 @@ function L2OALM_epoch!( current_state_dual = (;) # primal loop - while all(stopping_criterion(iter_primal, current_state_primal, training_step_loop_primal.hyperparameters) for stopping_criterion in training_step_loop_primal.stopping_criteria) + while all( + stopping_criterion( + iter_primal, + current_state_primal, + training_step_loop_primal.hyperparameters, + ) for stopping_criterion in training_step_loop_primal.stopping_criteria + ) current_states_primal = Vector{NamedTuple}(undef, num_batches) iter_batch = 1 for (θ) in data _, loss_val, stats, train_state_primal = Training.single_train_step!( AutoZygote(), # AD backend training_step_loop_primal.loss_fn, # Loss function - (θ, training_step_loop_primal.hyperparameters, training_step_loop_primal.pre_hook(θ, primal_model, train_state_primal, dual_model, train_state_dual)...), # Data - train_state_primal # Training state + ( + θ, + training_step_loop_primal.hyperparameters, + training_step_loop_primal.pre_hook( + θ, + primal_model, + train_state_primal, + dual_model, + train_state_dual, + )..., + ), # Data + train_state_primal, # Training state ) current_states_primal[iter_batch] = stats iter_batch += 1 end - current_state_primal = training_step_loop_primal.reconcile_state(current_states_primal) + current_state_primal = + training_step_loop_primal.reconcile_state(current_states_primal) iter_primal += 1 end for fn in training_step_loop_primal.parameter_update_fns - fn(training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters, current_state_primal) + fn( + training_step_loop_primal.hyperparameters, + training_step_loop_dual.hyperparameters, + current_state_primal, + ) end # dual loop - while all(stopping_criterion(iter_dual, current_state_dual, training_step_loop_dual.hyperparameters) for stopping_criterion in training_step_loop_dual.stopping_criteria) + while all( + stopping_criterion( + iter_dual, + current_state_dual, + training_step_loop_dual.hyperparameters, + ) for stopping_criterion in training_step_loop_dual.stopping_criteria + ) current_states_dual = 
Vector{NamedTuple}(undef, num_batches) iter_batch = 1 for (θ) in data _, loss_val, stats, train_state_dual = Training.single_train_step!( AutoZygote(), # AD backend training_step_loop_dual.loss_fn, # Loss function - (θ, training_step_loop_dual.hyperparameters, training_step_loop_dual.pre_hook(θ, primal_model, train_state_primal, dual_model, train_state_dual)...), # Data - train_state_dual # Training state + ( + θ, + training_step_loop_dual.hyperparameters, + training_step_loop_dual.pre_hook( + θ, + primal_model, + train_state_primal, + dual_model, + train_state_dual, + )..., + ), # Data + train_state_dual, # Training state ) current_states_dual[iter_batch] = stats iter_batch += 1 @@ -296,7 +348,11 @@ function L2OALM_epoch!( iter_dual += 1 end for fn in training_step_loop_dual.parameter_update_fns - fn(training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters, current_state_dual) + fn( + training_step_loop_primal.hyperparameters, + training_step_loop_dual.hyperparameters, + current_state_dual, + ) end return end @@ -334,12 +390,25 @@ function L2OALM_train!( train_state_primal::Lux.Training.TrainState, train_state_dual::Lux.Training.TrainState, data; - training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), - training_step_loop_dual::TrainingStepLoop=_default_dual_loop(num_equal), - stopping_criteria::Vector{F}=[(iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> iter >= 100 ? true : false], -) where F<:Function + training_step_loop_primal::TrainingStepLoop = _default_primal_loop(bm), + training_step_loop_dual::TrainingStepLoop = _default_dual_loop(num_equal), + stopping_criteria::Vector{F} = [ + (iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> + iter >= 100 ? true : false, + ], +) where {F<:Function} iter = 1 - while all(stopping_criterion(iter, primal_model, dual_model, train_state_primal, train_state_dual, training_step_loop_primal.hyperparameters, training_step_loop_dual.hyperparameters) for stopping_criterion in stopping_criteria) + while all( + stopping_criterion( + iter, + primal_model, + dual_model, + train_state_primal, + train_state_dual, + training_step_loop_primal.hyperparameters, + training_step_loop_dual.hyperparameters, + ) for stopping_criterion in stopping_criteria + ) L2OALM_epoch!( primal_model, train_state_primal, @@ -347,7 +416,7 @@ function L2OALM_train!( train_state_dual, training_step_loop_primal, training_step_loop_dual, - data + data, ) iter += 1 end diff --git a/test/power.jl b/test/power.jl index 0e021df..6644468 100644 --- a/test/power.jl +++ b/test/power.jl @@ -16,31 +16,31 @@ end _convert_array(data::N, backend) where {names,N<:NamedTuple{names}} = NamedTuple{names}(ExaModels.convert_array(d, backend) for d in data) -function _parse_ac_data_raw(filename; T=Float64) +function _parse_ac_data_raw(filename; T = Float64) ref = _build_power_ref(filename) # FIXME: only parse once - arcdict = Dict(a => k for (k, a) in enumerate(ref[:arcs])) - busdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:bus])) - gendict = Dict(k => i for (i, (k, _)) in enumerate(ref[:gen])) + arcdict = Dict(a => k for (k, a) in enumerate(ref[:arcs])) + busdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:bus])) + gendict = Dict(k => i for (i, (k, _)) in enumerate(ref[:gen])) branchdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:branch])) return ( bus = [ begin - loads = [ref[:load][l] for l in ref[:bus_loads][k]] - shunts = [ref[:shunt][s] for s in ref[:bus_shunts][k]] - pd = T(sum(load["pd"] for 
load in loads; init = 0.0)) - qd = T(sum(load["qd"] for load in loads; init = 0.0)) - gs = T(sum(shunt["gs"] for shunt in shunts; init = 0.0)) - bs = T(sum(shunt["bs"] for shunt in shunts; init = 0.0)) + loads = [ref[:load][l] for l in ref[:bus_loads][k]] + shunts = [ref[:shunt][s] for s in ref[:bus_shunts][k]] + pd = T(sum(load["pd"] for load in loads; init = 0.0)) + qd = T(sum(load["qd"] for load in loads; init = 0.0)) + gs = T(sum(shunt["gs"] for shunt in shunts; init = 0.0)) + bs = T(sum(shunt["bs"] for shunt in shunts; init = 0.0)) (i = busdict[k], pd = pd, gs = gs, qd = qd, bs = bs) end for (k, _) in ref[:bus] ], gen = [ ( - i = gendict[k], + i = gendict[k], cost1 = T(v["cost"][1]), cost2 = T(v["cost"][2]), cost3 = T(v["cost"][3]), - bus = busdict[v["gen_bus"]], + bus = busdict[v["gen_bus"]], ) for (k, v) in ref[:gen] ], arc = [ @@ -53,65 +53,72 @@ function _parse_ac_data_raw(filename; T=Float64) f_idx = arcdict[i, branch["f_bus"], branch["t_bus"]] t_idx = arcdict[i, branch["t_bus"], branch["f_bus"]] - g, b = PowerModels.calc_branch_y(branch) + g, b = PowerModels.calc_branch_y(branch) tr, ti = PowerModels.calc_branch_t(branch) - ttm = tr^2 + ti^2 + ttm = tr^2 + ti^2 - g_fr = branch["g_fr"]; b_fr = branch["b_fr"] - g_to = branch["g_to"]; b_to = branch["b_to"] + g_fr = branch["g_fr"] + b_fr = branch["b_fr"] + g_to = branch["g_to"] + b_to = branch["b_to"] ( - i = branchdict[i], - j = 1, - f_idx = f_idx, - t_idx = t_idx, - f_bus = busdict[branch["f_bus"]], - t_bus = busdict[branch["t_bus"]], - c1 = T((-g * tr - b * ti) / ttm), - c2 = T((-b * tr + g * ti) / ttm), - c3 = T((-g * tr + b * ti) / ttm), - c4 = T((-b * tr - g * ti) / ttm), - c5 = T((g + g_fr) / ttm), - c6 = T((b + b_fr) / ttm), - c7 = T((g + g_to)), - c8 = T((b + b_to)), + i = branchdict[i], + j = 1, + f_idx = f_idx, + t_idx = t_idx, + f_bus = busdict[branch["f_bus"]], + t_bus = busdict[branch["t_bus"]], + c1 = T((-g * tr - b * ti) / ttm), + c2 = T((-b * tr + g * ti) / ttm), + c3 = T((-g * tr + b * ti) / ttm), + c4 = T((-b * tr - g * ti) / ttm), + c5 = T((g + g_fr) / ttm), + c6 = T((b + b_fr) / ttm), + c7 = T((g + g_to)), + c8 = T((b + b_to)), rate_a_sq = T(branch["rate_a"]^2), ) end for (i, branch_raw) in ref[:branch] ], ref_buses = [busdict[i] for (i, _) in ref[:ref_buses]], - vmax = [T(v["vmax"]) for (_, v) in ref[:bus]], - vmin = [T(v["vmin"]) for (_, v) in ref[:bus]], - pmax = [T(v["pmax"]) for (_, v) in ref[:gen]], - pmin = [T(v["pmin"]) for (_, v) in ref[:gen]], - qmax = [T(v["qmax"]) for (_, v) in ref[:gen]], - qmin = [T(v["qmin"]) for (_, v) in ref[:gen]], - rate_a = [T(ref[:branch][l]["rate_a"]) for (l, _, _) in ref[:arcs]], - angmax = [T(b["angmax"]) for (_, b) in ref[:branch]], - angmin = [T(b["angmin"]) for (_, b) in ref[:branch]], + vmax = [T(v["vmax"]) for (_, v) in ref[:bus]], + vmin = [T(v["vmin"]) for (_, v) in ref[:bus]], + pmax = [T(v["pmax"]) for (_, v) in ref[:gen]], + pmin = [T(v["pmin"]) for (_, v) in ref[:gen]], + qmax = [T(v["qmax"]) for (_, v) in ref[:gen]], + qmin = [T(v["qmin"]) for (_, v) in ref[:gen]], + rate_a = [T(ref[:branch][l]["rate_a"]) for (l, _, _) in ref[:arcs]], + angmax = [T(b["angmax"]) for (_, b) in ref[:branch]], + angmin = [T(b["angmin"]) for (_, b) in ref[:branch]], ) end _parse_ac_data(filename) = _parse_ac_data_raw(filename) -function _parse_ac_data(filename, backend; T=Float64) - _convert_array(_parse_ac_data_raw(filename, T=T), backend) +function _parse_ac_data(filename, backend; T = Float64) + _convert_array(_parse_ac_data_raw(filename, T = T), backend) end # Parametric 
version -function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_ieee.m"; - prod::Bool = true, backend = OpenCLBackend(), T=Float64, kwargs...) - data = _parse_ac_data(filename, backend, T=T) +function create_parametric_ac_power_model( + filename::String = "pglib_opf_case14_ieee.m"; + prod::Bool = true, + backend = OpenCLBackend(), + T = Float64, + kwargs..., +) + data = _parse_ac_data(filename, backend, T = T) c = ExaCore(T; backend = backend) va = variable(c, length(data.bus);) vm = variable( - c, - length(data.bus); - start = fill!(similar(data.bus, T), 1.0), - lvar = data.vmin, - uvar = data.vmax, - ) + c, + length(data.bus); + start = fill!(similar(data.bus, T), 1.0), + lvar = data.vmin, + uvar = data.vmax, + ) pg = variable(c, length(data.gen); lvar = data.pmin, uvar = data.pmax) qg = variable(c, length(data.gen); lvar = data.qmin, uvar = data.qmax) @@ -129,40 +136,52 @@ function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_i # Branch power-flow equations --------------------------------------------- constraint( c, - (p[b.f_idx] - b.c5 * vm[b.f_bus]^2 - - b.c3 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) - - b.c4 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for - b in data.branch), + ( + p[b.f_idx] - b.c5 * vm[b.f_bus]^2 - + b.c3 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) - + b.c4 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for + b in data.branch + ), ) constraint( c, - (q[b.f_idx] + b.c6 * vm[b.f_bus]^2 + - b.c4 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) - - b.c3 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for - b in data.branch), + ( + q[b.f_idx] + + b.c6 * vm[b.f_bus]^2 + + b.c4 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) - + b.c3 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for + b in data.branch + ), ) constraint( c, - (p[b.t_idx] - b.c7 * vm[b.t_bus]^2 - - b.c1 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) - - b.c2 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for - b in data.branch), + ( + p[b.t_idx] - b.c7 * vm[b.t_bus]^2 - + b.c1 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) - + b.c2 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for + b in data.branch + ), ) constraint( c, - (q[b.t_idx] + b.c8 * vm[b.t_bus]^2 + - b.c2 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) - - b.c1 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for - b in data.branch), + ( + q[b.t_idx] + + b.c8 * vm[b.t_bus]^2 + + b.c2 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) - + b.c1 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for + b in data.branch + ), ) # Angle difference limits -------------------------------------------------- constraint( c, - va[b.f_bus] - va[b.t_bus] for b in data.branch; lcon = data.angmin, ucon = data.angmax, + va[b.f_bus] - va[b.t_bus] for b in data.branch; + lcon = data.angmin, + ucon = data.angmax, ) # Apparent power thermal limits ------------------------------------------- @@ -178,12 +197,14 @@ function create_parametric_ac_power_model(filename::String = "pglib_opf_case14_i ) # Power balance at each bus ----------------------------------------------- - load_balance_p = constraint(c, b.pd * load_multiplier[b.i] + b.gs * vm[b.i]^2 for b in data.bus) - load_balance_q = constraint(c, b.qd * load_multiplier[b.i] - b.bs * vm[b.i]^2 for b in data.bus) + 
load_balance_p = + constraint(c, b.pd * load_multiplier[b.i] + b.gs * vm[b.i]^2 for b in data.bus) + load_balance_q = + constraint(c, b.qd * load_multiplier[b.i] - b.bs * vm[b.i]^2 for b in data.bus) # Map arc & generator variables into the bus balance equations - constraint!(c, load_balance_p, a.bus => p[a.i] for a in data.arc) - constraint!(c, load_balance_q, a.bus => q[a.i] for a in data.arc) + constraint!(c, load_balance_p, a.bus => p[a.i] for a in data.arc) + constraint!(c, load_balance_q, a.bus => q[a.i] for a in data.arc) constraint!(c, load_balance_p, g.bus => -pg[g.i] for g in data.gen) constraint!(c, load_balance_q, g.bus => -qg[g.i] for g in data.gen) diff --git a/test/runtests.jl b/test/runtests.jl index 814ea0b..7d59944 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,10 +33,10 @@ include("power.jl") """ # Combine all layers: input size, hidden sizes, output size layer_sizes = [num_p; hidden_layers; num_y] - + # Build up a list of Dense layers dense_layers = Any[] - for i in 1:(length(layer_sizes)-1) + for i = 1:(length(layer_sizes)-1) if i < length(layer_sizes) - 1 # Hidden layers with activation push!(dense_layers, Dense(layer_sizes[i], layer_sizes[i+1], activation)) @@ -45,29 +45,32 @@ include("power.jl") push!(dense_layers, Dense(layer_sizes[i], layer_sizes[i+1])) end end - + return Chain(dense_layers...) end - - function test_alm_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = gpu_device(), - backend=CPU(), + + function test_alm_training(; + filename = "pglib_opf_case14_ieee.m", + dev_gpu = gpu_device(), + backend = CPU(), batch_size = 32, dataset_size = 3200, rng = Random.default_rng(), T = Float64, ) - model, nbus, ngen, blines = create_parametric_ac_power_model(filename; backend = backend, T=T) - bm_train = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) - bm_test = BNK.BatchModel(model, dataset_size, config=BNK.BatchModelConfig(:full)) - + model, nbus, ngen, blines = + create_parametric_ac_power_model(filename; backend = backend, T = T) + bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full)) + bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full)) + nvar = model.meta.nvar ncon = model.meta.ncon nθ = length(model.θ) num_equal = nbus * 2 - + Θ_train = randn(T, nθ, dataset_size) |> dev_gpu Θ_test = randn(T, nθ, dataset_size) |> dev_gpu - + primal_model = feed_forward_builder(nθ, nvar, [320, 320]) ps_primal, st_primal = Lux.setup(rng, primal_model) ps_primal = ps_primal |> dev_gpu @@ -77,47 +80,63 @@ include("power.jl") ps_dual, st_dual = Lux.setup(rng, dual_model) ps_dual = ps_dual |> dev_gpu st_dual = st_dual |> dev_gpu - - X̂ , _ = primal_model(Θ_test, ps_primal, st_primal) - + + X̂, _ = primal_model(Θ_test, ps_primal, st_primal) + y = BNK.objective!(bm_test, X̂, Θ_test) - + @test length(y) == dataset_size Vc, Vb = BNK.all_violations!(bm_test, X̂, Θ_test) @test size(Vc) == (ncon, dataset_size) @test size(Vb) == (nvar, dataset_size) - + # lagrangian_prev = sum(y) + 1000 * sum(Vc) + 1000 * sum(Vb) - - train_state_primal = Training.TrainState(primal_model, ps_primal, st_primal, Optimisers.Adam(1e-5)) - train_state_dual = Training.TrainState(dual_model, ps_dual, st_dual, Optimisers.Adam(1e-5)) - - data = DataLoader((Θ_train); batchsize=batch_size, shuffle=true) .|> dev_gpu + + train_state_primal = + Training.TrainState(primal_model, ps_primal, st_primal, Optimisers.Adam(1e-5)) + train_state_dual = + Training.TrainState(dual_model, ps_dual, st_dual, Optimisers.Adam(1e-5)) + + 
data = DataLoader((Θ_train); batchsize = batch_size, shuffle = true) .|> dev_gpu function validation_testset( - iter, primal_model, dual_model, train_state_primal, train_state_dual, hpm_primal, hpm_dual; max_dual=1e6 + iter, + primal_model, + dual_model, + train_state_primal, + train_state_dual, + hpm_primal, + hpm_dual; + max_dual = 1e6, ) ρ = hpm_primal[:ρ] - X̂_test , _ = primal_model(Θ_test, train_state_primal.parameters, train_state_primal.states) + X̂_test, _ = primal_model( + Θ_test, + train_state_primal.parameters, + train_state_primal.states, + ) objs_test = BNK.objective!(bm_test, X̂_test, Θ_test) Vc_test, Vb_test = BNK.all_violations!(bm_test, X̂_test, Θ_test) gh_test = BNK.constraints!(bm_test, X̂_test, Θ_test) - dual_hat, _ = dual_model(Θ_test, train_state_dual.parameters, train_state_dual.states) + dual_hat, _ = + dual_model(Θ_test, train_state_dual.parameters, train_state_dual.states) # Separate bound and equality constraints - gh_bound = gh_test[1:end-num_equal,:] - gh_equal = gh_test[end-num_equal+1:end,:] - dual_hat_bound = dual_hat[1:end-num_equal,:] - dual_hat_equal = dual_hat[end-num_equal+1:end,:] - + gh_bound = gh_test[1:end-num_equal, :] + gh_equal = gh_test[end-num_equal+1:end, :] + dual_hat_bound = dual_hat[1:end-num_equal, :] + dual_hat_equal = dual_hat[end-num_equal+1:end, :] + # Target for dual variables dual_target = vcat( min.(max.(dual_hat_bound + ρ .* gh_bound, 0), max_dual), - min.(dual_hat_equal + ρ .* gh_equal, max_dual) + min.(dual_hat_equal + ρ .* gh_equal, max_dual), ) - - dual_loss = mean((dual_hat .- dual_target).^2) - - @info "Validation Testset: Iteration $iter" mean(objs_test) mean(Vc_test) mean(Vb_test) dual_loss + + dual_loss = mean((dual_hat .- dual_target) .^ 2) + + @info "Validation Testset: Iteration $iter" mean(objs_test) mean(Vc_test) mean( + Vb_test, + ) dual_loss return iter >= 100 ? 
true : false end @@ -129,17 +148,22 @@ include("power.jl") train_state_primal, train_state_dual, data, - stopping_criteria=[validation_testset], + stopping_criteria = [validation_testset], ) end - + @testset "Penalty Training" begin backend, dev = if haskey(ENV, "BNK_TEST_CUDA") CUDABackend(), gpu_device() else CPU(), cpu_device() end - - test_alm_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = dev, backend=backend, T=Float32) + + test_alm_training(; + filename = "pglib_opf_case14_ieee.m", + dev_gpu = dev, + backend = backend, + T = Float32, + ) end end From 0797bb71d848a11b29ba61fa4083792196c4b2b0 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 22 Jul 2025 17:49:41 -0400 Subject: [PATCH 14/15] add source bnk --- Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index d05df0f..66548ca 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,9 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[sources] +BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"} + [compat] julia = "1.6.7" From b894dd3b5b9afe489f72db7923fd5323e1ca2995 Mon Sep 17 00:00:00 2001 From: Andrew Date: Tue, 29 Jul 2025 19:28:32 -0400 Subject: [PATCH 15/15] change API --- Project.toml | 7 +- src/L2OALM.jl | 503 +++++++++++++++++++++-------------------------- test/runtests.jl | 67 ++----- 3 files changed, 247 insertions(+), 330 deletions(-) diff --git a/Project.toml b/Project.toml index 66548ca..f3699aa 100644 --- a/Project.toml +++ b/Project.toml @@ -6,15 +6,13 @@ version = "1.0.0-DEV" [deps] BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -[sources] -BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"} - [compat] julia = "1.6.7" @@ -27,5 +25,8 @@ PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources.BatchNLPKernels] +url = "https://github.com/klamike/BatchNLPKernels.jl" + [targets] test = ["Test", "PowerModels", "PGLib", "Random", "MLUtils", "KernelAbstractions", "GPUArraysCore"] diff --git a/src/L2OALM.jl b/src/L2OALM.jl index 4e1e707..ef6836c 100644 --- a/src/L2OALM.jl +++ b/src/L2OALM.jl @@ -8,29 +8,155 @@ using LuxCUDA using Lux.Training using CUDA using Statistics +using ChainRules: @ignore_derivatives -export LagrangianDualLoss, - LagrangianPrimalLoss, TrainingStepLoop, L2OALM_epoch!, L2OALM_train! +export ALMMethod, + ALMTrainer, dual_loss, primal_loss, single_train_step! """ - LagrangianDualLoss(;max_dual=1e6) + AbstractL2OMethod + +Abstract type for Learning to Optimize (L2O) methods. +""" +abstract type AbstractL2OMethod end + +""" + AbstractPrimalDualMethod + +Abstract type for Primal-Dual Learning to Optimize (L2O) methods. +""" +abstract type AbstractPrimalDualMethod <: AbstractL2OMethod end + +""" + ALMMethod{T<:Real} <: AbstractPrimalDualMethod + +Augmented Lagrangian Method (ALM) for primal-dual Learning to Optimize (L2O) methods. 
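For orientation, the objects introduced below are constructed and driven exactly as in test/runtests.jl; a minimal sketch (assuming `bm_train`, `num_equal`, the Lux models, their train states and a `DataLoader` `data` are prepared as in the tests):

    method  = ALMMethod(; batch_model = bm_train, num_equal = num_equal)   # max_dual, ρmax, τ, α keep their defaults
    trainer = ALMTrainer(primal_model, train_state_primal, dual_model, train_state_dual)
    single_train_step!(method, trainer, data)    # one outer ALM step: primal loop, dual loop, trainer update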
+""" +struct ALMMethod{T<:Real} <: AbstractPrimalDualMethod + batch_model::BatchModel + max_dual::T + ρmax::T + τ::T + α::T + num_equal::Int # TODO: There should be a way to get this from the batch model + + function ALMMethod(batch_model::BatchModel, max_dual::T, ρmax::T, τ::T, α::T, num_equal::Int) where {T<:Real} + new{T}(batch_model, max_dual, ρmax, τ, α, num_equal) + end +end + +""" + ALMMethod(; batch_model::BatchModel, num_equal::Int, max_dual::T = 1e6, ρmax::T = 1e6, τ::T = 0.8, α::T = 10.0) + +Constructor for the Augmented Lagrangian Method (ALM) for primal-dual Learning to Optimize (L2O) methods. +This function initializes the ALM method with a batch model, maximum dual variable values, maximum learning rate `ρ`, +threshold for the parameter updater `τ`, meta learning rate `α`, and the number of equality constraints. +""" +function ALMMethod(; + batch_model::BatchModel, + num_equal::Int, + max_dual::T = 1e6, + ρmax::T = 1e6, + τ::T = 0.8, + α::T = 10.0, +) where {T<:Real} + return ALMMethod(batch_model, max_dual, ρmax, τ, α, num_equal) +end + +""" + AbstractL2OTrainer + +An abstract type for a structure that holds the training state for Learning to Optimize (L2O) methods. +""" +abstract type AbstractL2OTrainer end + +""" + AbstractPrimalDualTrainer + +An abstract type for a structure that holds the training state for Primal-Dual Learning to Optimize (L2O) methods. +""" +abstract type AbstractPrimalDualTrainer <: AbstractL2OTrainer end + +""" + ALMTrainer{T<:Real} <: AbstractPrimalDualTrainer + +A structure that holds the training state for the Augmented Lagrangian Method (ALM) for primal-dual Learning to Optimize (L2O) methods. +""" +mutable struct ALMTrainer{T<:Real} <: AbstractPrimalDualTrainer + primal_model::Lux.Chain + primal_training_state::Lux.Training.TrainState + dual_model::Lux.Chain + dual_training_state::Lux.Training.TrainState + ρ::T + prev_dual_training_state::Lux.Training.TrainState + max_violations::T + mean_violations::T + mean_objs::T + total_loss::T + dual_loss::T + + function ALMTrainer( + primal_model::Lux.Chain, + primal_training_state::Lux.Training.TrainState, + dual_model::Lux.Chain, + dual_training_state::Lux.Training.TrainState, + ρ::T = 1.0, + prev_dual_training_state::Lux.Training.TrainState = deepcopy(dual_training_state), + max_violations::T = Inf, + mean_violations::T = Inf, + mean_objs::T = Inf, + total_loss::T = Inf, + dual_loss::T = Inf, + ) where {T<:Real} + new{T}( + primal_model, primal_training_state, dual_model, dual_training_state, ρ, prev_dual_training_state, + max_violations, mean_violations, mean_objs, total_loss, dual_loss + ) + end +end + +function ALMTrainer(; + primal_model::Lux.Chain, + primal_training_state::Lux.Training.TrainState, + dual_model::Lux.Chain, + dual_training_state::Lux.Training.TrainState, + ρ::T = 1.0, +) where {T<:Real} + return ALMTrainer{T}( + primal_model, primal_training_state, dual_model, dual_training_state, ρ, + ) +end + +""" + dual_loss(method::ALMMethod) Returns a function that computes the MSE loss for the (dual-)model predicting lagrangian dual variables -from constraint evaluations `gh` and the current dual predictions `dual_hat_k`. +from constraint evaluations and the last dual predictions. Target is calculated using the augmented lagrangian method. Target dual variables are clipped from zero to `max_dual`. - -Keywords: - - `max_dual`: Maximum value for the target dual variables. 
""" -function LagrangianDualLoss(num_equal::Int; max_dual = 1e6) +function dual_loss(method::ALMMethod) + bm = method.batch_model + max_dual = method.max_dual + num_equal = method.num_equal return (dual_model, ps_dual, st_dual, data) -> begin - x, hpm, dual_hat_k, gh = data - ρ = hpm[:ρ] + Θ, trainer = data + ρ = trainer.ρ + prev_dual_training_state = trainer.prev_dual_training_state + primal_training_state = trainer.primal_training_state + # Get current dual predictions - dual_hat, st_dual_new = dual_model(x, ps_dual, st_dual) + dual_hat, st_dual_new = dual_model(Θ, ps_dual, st_dual) + + # Get previous dual predictions + dual_hat_k, _ = @ignore_derivatives dual_model(Θ, prev_dual_training_state.parameters, prev_dual_training_state.states) # Separate bound and equality constraints + # # Forward pass for primal model + X̂, _ = trainer.primal_model(θ, primal_training_state.parameters, primal_training_state.states) + gh = constraints!(bm, X̂, Θ) + + @ignore_derivatives gh = trainer.constraints(Θ, trainer.ps_constraints, trainer.st_constraints) gh_bound = gh[1:end-num_equal, :] gh_equal = gh[end-num_equal+1:end, :] dual_hat_bound = dual_hat_k[1:end-num_equal, :] @@ -48,23 +174,24 @@ function LagrangianDualLoss(num_equal::Int; max_dual = 1e6) end """ - LagrangianPrimalLoss(bm::BatchModel) + primal_loss(method::ALMMethod) Returns a function that computes the augmented lagrangian primal loss -from current dual predictions `dual_hat` for the batch model `bm` under parameters `Θ`. - -Arguments: - - `bm`: A `BatchModel` instance that contains the model and batch configuration. +from current dual predictions for the batch model `bm` under parameters `Θ`. """ -function LagrangianPrimalLoss(bm::BatchModel) +function primal_loss(method::ALMMethod) + bm = method.batch_model return (model, ps, st, data) -> begin - Θ, hpm, dual_hat = data - ρ = hpm[:ρ] + Θ, trainer = data + ρ = trainer.ρ num_s = size(Θ, 2) # Forward pass for prediction X̂, st_new = model(Θ, ps, st) + # Get current dual predictions + dual_hat, _ = @ignore_derivatives trainer.dual_model(Θ, trainer.dual_training_state.parameters, trainer.dual_training_state.states) + # Calculate violations and objectives objs = BNK.objective!(bm, X̂, Θ) # gh = constraints!(bm, X̂, Θ) @@ -79,91 +206,26 @@ function LagrangianPrimalLoss(bm::BatchModel) ( total_loss = total_loss, mean_violations = mean(V), - new_max_violation = maximum(V), + max_violation = maximum(V), mean_objs = mean(objs), ) end end """ - TrainingStepLoop - -A structure to define a training step loop for the L2O-ALM algorithm. - -Fields: -- `loss_fn`: Function to compute the loss for the training step. -- `stopping_criteria`: Vector of functions that determine when to stop the training loop. -- `hyperparameters`: Dictionary of hyperparameters and hyper-states for the training step. -- `parameter_update_fns`: Vector of functions to update hyperparameters after each training step. -- `reconcile_state`: Function to reconcile the state after processing a batch of data. -""" -mutable struct TrainingStepLoop - loss_fn::Function - stopping_criteria::Vector{Function} - hyperparameters::Dict{Symbol,Any} - parameter_update_fns::Vector{Function} - reconcile_state::Function - pre_hook::Function -end - -""" - _pre_hook_primal(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) - -Default pre-hook function for the primal model in the L2O-ALM algorithm. -This function performs a forward pass through the dual model to obtain the dual predictions. 
-""" -function _pre_hook_primal( - θ, - primal_model, - train_state_primal, - dual_model, - train_state_dual, - bm, -) - # Forward pass for dual model - dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states) - - return (dual_hat_k,) -end - -""" - _pre_hook_dual(θ, primal_model, train_state_primal, dual_model, train_state_dual, bm) - -Default pre-hook function for the dual model in the L2O-ALM algorithm. -This function performs a forward pass through the primal model to obtain the predicted state and constraints. -""" -function _pre_hook_dual( - θ, - primal_model, - train_state_primal, - dual_model, - train_state_dual, - bm, -) - # # Forward pass for primal model - X̂, _ = primal_model(θ, train_state_primal.parameters, train_state_primal.states) - gh = constraints!(bm, X̂, Θ) + reconcile_primal(::AbstractL2OTrainer, batch_states::Vector{NamedTuple}) - # Forward pass for dual model - dual_hat, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states) - - return (dual_hat, gh) -end - -""" - _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) - -Default function that reconciles the state of the primal model after processing a batch of data. +Reconciles the state of the primal model after processing a batch of data. This function computes the maximum violation, mean violations, mean objectives, and total loss from the batch states. """ -function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) - max_violation = maximum([s.new_max_violation for s in batch_states]) +function reconcile_primal(::AbstractL2OTrainer, batch_states::Vector{NamedTuple}) + max_violation = maximum([s.max_violation for s in batch_states]) mean_violations = mean([s.mean_violations for s in batch_states]) mean_objs = mean([s.mean_objs for s in batch_states]) mean_loss = mean([s.total_loss for s in batch_states]) return (; - new_max_violation = max_violation, + max_violation = max_violation, mean_violations = mean_violations, mean_objs = mean_objs, total_loss = mean_loss, @@ -171,99 +233,91 @@ function _reconcile_alm_primal_state(batch_states::Vector{NamedTuple}) end """ - _reconcile_alm_dual_state(batch_states::Vector{NamedTuple}) + reconcile_dual(::AbstractL2OTrainer, batch_states::Vector{NamedTuple}) -Default function that reconciles the state of the dual model after processing a batch of data. +Reconciles the state of the dual model after processing a batch of data. This function computes the mean dual loss from the batch states. """ -function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple}) - dual_loss = mean([s.dual_loss for s in batch_states]) - return (dual_loss = dual_loss,) +function reconcile_dual(::AbstractL2OTrainer, batch_states::Vector{NamedTuple}) + return (dual_loss = mean([s.dual_loss for s in batch_states]),) end """ - _update_ALM_ρ!(hpm::Dict{Symbol, Any}, current_state::NamedTuple) + update_trainer!(method::ALMMethod, trainer::ALMTrainer, + primal_state::NamedTuple, dual_state::NamedTuple) -Default function to update the hyperparameter ρ in the ALM algorithm. +Update the hyperparameters and states in the ALM algorithm. This function increases ρ by a factor of α if the new maximum violation exceeds τ times the previous maximum violation. 
""" -function _update_ALM_ρ!( - hpm_primal::Dict{Symbol,Any}, - hpm_dual::Dict{Symbol,Any}, - current_state::NamedTuple, +function update_trainer!( + method::ALMMethod, + trainer::ALMTrainer, + primal_state::NamedTuple, + dual_state::NamedTuple, ) - if current_state.new_max_violation > hpm_primal.τ * hpm_primal.max_violation - hpm_primal[:ρ] = min(hpm_primal[:ρmax], hpm_primal[:ρ] * hpm_primal[:α]) - hpm_dual[:ρ] = hpm_primal[:ρ] # Ensure dual model uses the same ρ + # Update primal state + new_max_violations = primal_state.max_violations + trainer.mean_violations = primal_state.mean_violations + trainer.mean_objs = primal_state.mean_objs + trainer.total_loss = primal_state.total_loss + # Update ρ if necessary + if new_max_violations > method.τ * trainer.max_violations + trainer.ρ = min(method.ρmax, trainer.ρ * method.α) end - hpm_primal.max_violation = current_state.new_max_violation + trainer.max_violations = new_max_violations + + # Update dual state + trainer.dual_loss = dual_state.dual_loss + trainer.prev_dual_training_state = deepcopy(trainer.dual_training_state) + return end """ - _default_primal_loop(bm::BatchModel) - -Returns a default `TrainingStepLoop` for the primal model in the L2O-ALM algorithm. -""" -function _default_primal_loop(bm::BatchModel) - return TrainingStepLoop( - LagrangianPrimalLoss(bm), - [(iter, current_state, hpm) -> iter >= 100 ? true : false], - Dict{Symbol,Any}( - :ρ => 1.0, - :ρmax => 1e6, - :τ => 0.8, - :α => 10.0, - :max_violation => 0.0, - ), - [_update_ALM_ρ!], - _reconcile_alm_primal_state, - _pre_hook_primal, - ) + _stopping_criterion(iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer) + +Default stopping criterion for the L2O methods. +This function checks if the number of iterations has reached a predefined limit (100). +""" +function _stopping_criterion(iter::Int, ::M, ::N) where {M<:AbstractL2OMethod, N<:AbstractL2OTrainer} + return iter >= 100 ? true : false end """ - _default_dual_loop() + primal_stopping_criterion(iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer) -Returns a default `TrainingStepLoop` for the dual model in the L2O-ALM algorithm. +Default stopping criterion for primal learning methods. +This function checks if the number of iterations has reached a predefined limit (100). """ -function _default_dual_loop(num_equal::Int) - return TrainingStepLoop( - LagrangianDualLoss(num_equal), - [(iter, current_state, hpm) -> iter >= 100 ? true : false], - Dict{Symbol,Any}(:max_dual => 1e6, :ρ => 1.0), - [], - _reconcile_alm_dual_state, - _pre_hook_dual, - ) +function primal_stopping_criterion(iter::Int, method::M, trainer::N) where {M<:AbstractL2OMethod, N<:AbstractL2OTrainer} + return _stopping_criterion(iter, method, trainer) end """ - L2OALM_epoch(bm::BatchModel, primal_model::Lux.Chain, train_state_primal::Lux.Training.TrainState, - dual_model::Lux.Chain, train_state_dual::Lux.Training.TrainState, - training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), - training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), - data) + dual_stopping_criterion(iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer) + +Default stopping criterion for dual learning methods. +This function checks if the number of iterations has reached a predefined limit (100). 
+""" +function dual_stopping_criterion(iter::Int, method::M, trainer::N) where {M<:AbstractL2OMethod, N<:AbstractL2OTrainer} + return _stopping_criterion(iter, method, trainer) +end -Runs a single epoch of the L2O-ALM algorithm, training both primal and dual models. +""" + single_train_step!( + method::AbstractPrimalDualMethod, + trainer::AbstractPrimalDualTrainer, + data, +) -Arguments: -- `bm`: A `BatchModel` instance that contains the model and batch configuration. -- `primal_model`: The Lux model for the primal problem. -- `train_state_primal`: The training state for the primal model. -- `dual_model`: The Lux model for the dual problem. -- `train_state_dual`: The training state for the dual model. -- `training_step_loop_primal`: The training step loop for the primal model. -- `training_step_loop_dual`: The training step loop for the dual model. -- `data`: The training data, typically a collection of batches. +Performs a single training step for the primal-dual method. +This function loops through the primal learning method until the stopping criterion is met +with the dual model fixed, then the inverse is done with the dual learning method and finally +updates the trainer state. """ -function L2OALM_epoch!( - primal_model::Lux.Chain, - train_state_primal::Lux.Training.TrainState, - dual_model::Lux.Chain, - train_state_dual::Lux.Training.TrainState, - training_step_loop_primal::TrainingStepLoop, - training_step_loop_dual::TrainingStepLoop, +function single_train_step!( + method::AbstractPrimalDualMethod, + trainer::AbstractPrimalDualTrainer, data, ) iter_primal = 1 @@ -271,155 +325,48 @@ function L2OALM_epoch!( num_batches = length(data) current_state_primal = (;) current_state_dual = (;) + _primal_loss = primal_loss(method) + _dual_loss = dual_loss(method) + train_state_primal = trainer.primal_training_state + train_state_dual = trainer.dual_training_state # primal loop - while all( - stopping_criterion( - iter_primal, - current_state_primal, - training_step_loop_primal.hyperparameters, - ) for stopping_criterion in training_step_loop_primal.stopping_criteria - ) + while primal_stopping_criterion(iter_primal, method, trainer) current_states_primal = Vector{NamedTuple}(undef, num_batches) iter_batch = 1 for (θ) in data _, loss_val, stats, train_state_primal = Training.single_train_step!( AutoZygote(), # AD backend - training_step_loop_primal.loss_fn, # Loss function - ( - θ, - training_step_loop_primal.hyperparameters, - training_step_loop_primal.pre_hook( - θ, - primal_model, - train_state_primal, - dual_model, - train_state_dual, - )..., - ), # Data + _primal_loss, # Loss function + (θ, trainer), # Data train_state_primal, # Training state ) current_states_primal[iter_batch] = stats iter_batch += 1 end - current_state_primal = - training_step_loop_primal.reconcile_state(current_states_primal) + current_state_primal = reconcile_primal(trainer, current_states_primal) iter_primal += 1 end - for fn in training_step_loop_primal.parameter_update_fns - fn( - training_step_loop_primal.hyperparameters, - training_step_loop_dual.hyperparameters, - current_state_primal, - ) - end # dual loop - while all( - stopping_criterion( - iter_dual, - current_state_dual, - training_step_loop_dual.hyperparameters, - ) for stopping_criterion in training_step_loop_dual.stopping_criteria - ) + while dual_stopping_criterion(iter_dual, method, trainer) current_states_dual = Vector{NamedTuple}(undef, num_batches) iter_batch = 1 for (θ) in data _, loss_val, stats, train_state_dual = Training.single_train_step!( 
AutoZygote(), # AD backend - training_step_loop_dual.loss_fn, # Loss function - ( - θ, - training_step_loop_dual.hyperparameters, - training_step_loop_dual.pre_hook( - θ, - primal_model, - train_state_primal, - dual_model, - train_state_dual, - )..., - ), # Data + _dual_loss, # Loss function + (θ, trainer), # Data train_state_dual, # Training state ) current_states_dual[iter_batch] = stats iter_batch += 1 end - current_state_dual = training_step_loop_dual.reconcile_state(current_states_dual) + current_state_dual = reconcile_dual(trainer, current_states_dual) iter_dual += 1 end - for fn in training_step_loop_dual.parameter_update_fns - fn( - training_step_loop_primal.hyperparameters, - training_step_loop_dual.hyperparameters, - current_state_dual, - ) - end - return -end - -""" - L2OALM_train(bm::BatchModel, num_equal::Int, - primal_model::Lux.Chain, dual_model::Lux.Chain, - train_state_primal::Lux.Training.TrainState, train_state_dual::Lux.Training.TrainState, - training_step_loop_primal::TrainingStepLoop=_default_primal_loop(bm), - training_step_loop_dual::TrainingStepLoop=_default_dual_loop(), - stopping_criteria::Vector{Function}=[(iter, primal_model::Lux.Chain, dual_model::Lux.Chain, - train_state_primal::Lux.Training.TrainState, train_state_dual::Lux.Training.TrainState) -> iter >= 100 ? true : false], - data - ) - -Runs the L2O-ALM training algorithm until the stopping criteria are met. - -Arguments: -- `bm`: A `BatchModel` instance that contains the model and batch configuration. -- `num_equal`: The number of equality constraints in the problem, used for dual loss calculation. -- `primal_model`: The Lux model for the primal problem. -- `dual_model`: The Lux model for the dual problem. -- `train_state_primal`: The training state for the primal model. -- `train_state_dual`: The training state for the dual model. -- `data`: The training data, typically a collection of batches. -- `training_step_loop_primal`: The training step loop for the primal model. -- `training_step_loop_dual`: The training step loop for the dual model. -- `stopping_criteria`: A vector of functions that determine when to stop the training loop. -""" -function L2OALM_train!( - bm::BatchModel, - num_equal::Int, - primal_model::Lux.Chain, - dual_model::Lux.Chain, - train_state_primal::Lux.Training.TrainState, - train_state_dual::Lux.Training.TrainState, - data; - training_step_loop_primal::TrainingStepLoop = _default_primal_loop(bm), - training_step_loop_dual::TrainingStepLoop = _default_dual_loop(num_equal), - stopping_criteria::Vector{F} = [ - (iter, primal_model, dual_model, tr_st_primal, tr_st_dual, hpm_primal) -> - iter >= 100 ? 
true : false, - ], -) where {F<:Function} - iter = 1 - while all( - stopping_criterion( - iter, - primal_model, - dual_model, - train_state_primal, - train_state_dual, - training_step_loop_primal.hyperparameters, - training_step_loop_dual.hyperparameters, - ) for stopping_criterion in stopping_criteria - ) - L2OALM_epoch!( - primal_model, - train_state_primal, - dual_model, - train_state_dual, - training_step_loop_primal, - training_step_loop_dual, - data, - ) - iter += 1 - end + # Update trainer state + update_trainer!(method, trainer, current_state_primal, current_state_dual) return end diff --git a/test/runtests.jl b/test/runtests.jl index 7d59944..ef25104 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -99,57 +99,26 @@ include("power.jl") data = DataLoader((Θ_train); batchsize = batch_size, shuffle = true) .|> dev_gpu - function validation_testset( - iter, - primal_model, - dual_model, - train_state_primal, - train_state_dual, - hpm_primal, - hpm_dual; - max_dual = 1e6, - ) - ρ = hpm_primal[:ρ] - X̂_test, _ = primal_model( - Θ_test, - train_state_primal.parameters, - train_state_primal.states, - ) - objs_test = BNK.objective!(bm_test, X̂_test, Θ_test) - Vc_test, Vb_test = BNK.all_violations!(bm_test, X̂_test, Θ_test) - gh_test = BNK.constraints!(bm_test, X̂_test, Θ_test) - dual_hat, _ = - dual_model(Θ_test, train_state_dual.parameters, train_state_dual.states) - # Separate bound and equality constraints - gh_bound = gh_test[1:end-num_equal, :] - gh_equal = gh_test[end-num_equal+1:end, :] - dual_hat_bound = dual_hat[1:end-num_equal, :] - dual_hat_equal = dual_hat[end-num_equal+1:end, :] - - # Target for dual variables - dual_target = vcat( - min.(max.(dual_hat_bound + ρ .* gh_bound, 0), max_dual), - min.(dual_hat_equal + ρ .* gh_equal, max_dual), - ) - - dual_loss = mean((dual_hat .- dual_target) .^ 2) - - @info "Validation Testset: Iteration $iter" mean(objs_test) mean(Vc_test) mean( - Vb_test, - ) dual_loss - return iter >= 100 ? true : false + method = ALMMethod(; batch_model=bm_train, num_equal=num_equal) + trainer = ALMTrainer(primal_model, train_state_primal, dual_model, train_state_dual) + method_test = ALMMethod(; batch_model=bm_test, num_equal=num_equal) + test_primal_loss = primal_loss(method_test) + test_dual_loss = dual_loss(method_test) + + _, prev_primal_loss_val, _, _ = test_primal_loss(primal_model, ps_primal, st_primal, (Θ_test, trainer)) + + for iter in 1:100 + single_train_step!(method, trainer, data) + # Log + _, primal_loss_val, stats_primal, train_state_primal = test_primal_loss(primal_model, ps_primal, st_primal, (Θ_test, trainer)) + _, dual_loss_val, stats_dual, train_state_dual = test_dual_loss(dual_model, ps_dual, st_dual, (Θ_test, trainer)) + + @info "Validation Testset: Iteration $iter" primal_loss_val stats_primal.max_violation stats_primal.mean_violations stats_primal.mean_objs dual_loss_val end + # Check that the primal loss is decreasing + @test primal_loss_val < prev_primal_loss_val - L2OALM_train!( - bm_train, - num_equal, - primal_model, - dual_model, - train_state_primal, - train_state_dual, - data, - stopping_criteria = [validation_testset], - ) + return end @testset "Penalty Training" begin