diff --git a/Project.toml b/Project.toml index 41d83ca..f3699aa 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,30 @@ uuid = "f31bfc7b-7b5d-4cc3-b76b-1af281ce159d" authors = ["Andrew and contributors"] version = "1.0.0-DEV" +[deps] +BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [compat] julia = "1.6.7" [extras] +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +PGLib = "07a8691f-3d11-4330-951b-3c50f98338be" +PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources.BatchNLPKernels] +url = "https://github.com/klamike/BatchNLPKernels.jl" + [targets] -test = ["Test"] +test = ["Test", "PowerModels", "PGLib", "Random", "MLUtils", "KernelAbstractions", "GPUArraysCore"] diff --git a/README.md b/README.md index a4d73fd..a9fa2de 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ Learning To Optimize using the Augmented Lagrangian Primal-Dual Method. 
-[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://andrewrosemberg.github.io/L2OALM.jl/stable/) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://andrewrosemberg.github.io/L2OALM.jl/dev/) -[![Build Status](https://github.com/andrewrosemberg/L2OALM.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/andrewrosemberg/L2OALM.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/andrewrosemberg/L2OALM.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/andrewrosemberg/L2OALM.jl) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://LearningToOptimize.github.io/L2OALM.jl/stable/) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://LearningToOptimize.github.io/L2OALM.jl/dev/) +[![Build Status](https://github.com/LearningToOptimize/L2OALM.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/LearningToOptimize/L2OALM.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/LearningToOptimize/L2OALM.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/LearningToOptimize/L2OALM.jl) diff --git a/docs/make.jl b/docs/make.jl index e8c1d33..99f1798 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,23 +1,18 @@ using L2OALM using Documenter -DocMeta.setdocmeta!(L2OALM, :DocTestSetup, :(using L2OALM); recursive=true) +DocMeta.setdocmeta!(L2OALM, :DocTestSetup, :(using L2OALM); recursive = true) makedocs(; - modules=[L2OALM], - authors="Andrew and contributors", - sitename="L2OALM.jl", - format=Documenter.HTML(; - canonical="https://andrewrosemberg.github.io/L2OALM.jl", - edit_link="main", - assets=String[], + modules = [L2OALM], + authors = "Andrew and contributors", + sitename = "L2OALM.jl", + format = Documenter.HTML(; + canonical = "https://LearningToOptimize.github.io/L2OALM.jl", + edit_link = "main", + assets = String[], ), - pages=[ - "Home" => "index.md", - ], + pages = ["Home" => "index.md"], ) 
# =============================================================================
# docs/make.jl (tail)
# =============================================================================
deploydocs(; repo = "github.com/LearningToOptimize/L2OALM.jl", devbranch = "main")

# =============================================================================
# docs/src/index.md (excerpt, Markdown -- kept here as a comment)
# =============================================================================
# Documentation for [L2OALM](https://github.com/LearningToOptimize/L2OALM.jl).

# =============================================================================
# src/L2OALM.jl
# =============================================================================
module L2OALM

using BatchNLPKernels
using ExaModels

using Lux
using LuxCUDA
using Lux.Training
using CUDA
using Statistics
using ChainRules: @ignore_derivatives

export ALMMethod, ALMTrainer, dual_loss, primal_loss, single_train_step!

# Short alias; every BatchNLPKernels call below is qualified through it so the module
# does not depend on which names that package happens to export.
const BNK = BatchNLPKernels

# Iteration limit used by the default stopping criteria.
const DEFAULT_MAX_ITERATIONS = 100

"""
    AbstractL2OMethod

Abstract type for Learning to Optimize (L2O) methods.
"""
abstract type AbstractL2OMethod end

"""
    AbstractPrimalDualMethod

Abstract type for Primal-Dual Learning to Optimize (L2O) methods.
"""
abstract type AbstractPrimalDualMethod <: AbstractL2OMethod end

"""
    ALMMethod{T<:Real} <: AbstractPrimalDualMethod

Augmented Lagrangian Method (ALM) for primal-dual Learning to Optimize (L2O) methods.

# Fields
- `batch_model`: batched NLP model used to evaluate objectives and constraints.
- `max_dual`: upper clip applied to the dual targets.
- `ρmax`: largest value the penalty parameter `ρ` may reach.
- `τ`: violation-ratio threshold that triggers an increase of `ρ`.
- `α`: multiplicative factor applied to `ρ` when it is increased.
- `num_equal`: number of trailing constraint rows treated as equality constraints.
"""
struct ALMMethod{T<:Real} <: AbstractPrimalDualMethod
    batch_model::BNK.BatchModel
    max_dual::T
    ρmax::T
    τ::T
    α::T
    num_equal::Int # TODO: There should be a way to get this from the batch model

    function ALMMethod(
        batch_model::BNK.BatchModel,
        max_dual::T,
        ρmax::T,
        τ::T,
        α::T,
        num_equal::Int,
    ) where {T<:Real}
        return new{T}(batch_model, max_dual, ρmax, τ, α, num_equal)
    end
end

"""
    ALMMethod(; batch_model, num_equal, max_dual = 1e6, ρmax = 1e6, τ = 0.8, α = 10.0)

Keyword constructor for [`ALMMethod`](@ref).

The four real-valued hyperparameters are promoted to a common type, so they do not
have to be passed with identical types.

# Throws
- `ArgumentError` if `num_equal` is negative.
"""
function ALMMethod(;
    batch_model::BNK.BatchModel,
    num_equal::Integer,
    max_dual::Real = 1e6,
    ρmax::Real = 1e6,
    τ::Real = 0.8,
    α::Real = 10.0,
)
    num_equal >= 0 ||
        throw(ArgumentError("num_equal must be non-negative, got $num_equal"))
    hyper = promote(max_dual, ρmax, τ, α)
    return ALMMethod(batch_model, hyper..., Int(num_equal))
end

"""
    AbstractL2OTrainer

An abstract type for a structure that holds the training state for Learning to Optimize
(L2O) methods.
"""
abstract type AbstractL2OTrainer end

"""
    AbstractPrimalDualTrainer

An abstract type for a structure that holds the training state for Primal-Dual Learning
to Optimize (L2O) methods.
"""
abstract type AbstractPrimalDualTrainer <: AbstractL2OTrainer end

"""
    ALMTrainer{T<:Real} <: AbstractPrimalDualTrainer

Mutable training state for the Augmented Lagrangian Method: the primal and dual models
with their `TrainState`s, the current penalty `ρ`, a snapshot of the dual `TrainState`
from the previous outer iteration, and the most recent aggregate statistics.

    ALMTrainer(primal_model, primal_training_state, dual_model, dual_training_state,
               ρ = 1.0, prev_dual_training_state = deepcopy(dual_training_state),
               max_violations = Inf, mean_violations = Inf, mean_objs = Inf,
               total_loss = Inf, dual_loss = Inf)

`T` is the floating-point type of `ρ`; the statistics are converted to `T`.
"""
mutable struct ALMTrainer{T<:Real} <: AbstractPrimalDualTrainer
    primal_model::Lux.Chain
    primal_training_state::Lux.Training.TrainState
    dual_model::Lux.Chain
    dual_training_state::Lux.Training.TrainState
    ρ::T
    prev_dual_training_state::Lux.Training.TrainState
    max_violations::T
    mean_violations::T
    mean_objs::T
    total_loss::T
    dual_loss::T

    function ALMTrainer(
        primal_model::Lux.Chain,
        primal_training_state::Lux.Training.TrainState,
        dual_model::Lux.Chain,
        dual_training_state::Lux.Training.TrainState,
        ρ::Real = 1.0,
        prev_dual_training_state::Lux.Training.TrainState = deepcopy(dual_training_state),
        max_violations::Real = Inf,
        mean_violations::Real = Inf,
        mean_objs::Real = Inf,
        total_loss::Real = Inf,
        dual_loss::Real = Inf,
    )
        # `float` makes an integer ρ usable with the `Inf` defaults; `new` converts
        # every argument to the declared field type.
        T = float(typeof(ρ))
        return new{T}(
            primal_model,
            primal_training_state,
            dual_model,
            dual_training_state,
            ρ,
            prev_dual_training_state,
            max_violations,
            mean_violations,
            mean_objs,
            total_loss,
            dual_loss,
        )
    end
end

"""
    ALMTrainer(; primal_model, primal_training_state, dual_model, dual_training_state,
               ρ = 1.0)

Keyword constructor for [`ALMTrainer`](@ref); all statistics start at `Inf`.
"""
function ALMTrainer(;
    primal_model::Lux.Chain,
    primal_training_state::Lux.Training.TrainState,
    dual_model::Lux.Chain,
    dual_training_state::Lux.Training.TrainState,
    ρ::Real = 1.0,
)
    # Only the unparameterised inner constructor exists, so it must be called without
    # a type parameter (`ALMTrainer{T}(...)` has no method).
    return ALMTrainer(primal_model, primal_training_state, dual_model, dual_training_state, ρ)
end

"""
    dual_loss(method::ALMMethod)

Return a Lux-style loss function `(dual_model, ps_dual, st_dual, data) -> (loss, st, stats)`
where `data = (Θ, trainer)`.

The loss is the MSE between the current dual prediction and an ALM target built from the
previous dual prediction `λₖ` and the constraint values `g` of the current primal
prediction: the first `end - num_equal` rows use `clamp(λₖ + ρ g, 0, max_dual)`, the last
`num_equal` rows use `min(λₖ + ρ g, max_dual)`. The whole target is treated as a constant
by automatic differentiation.
"""
function dual_loss(method::ALMMethod)
    bm = method.batch_model
    max_dual = method.max_dual
    num_equal = method.num_equal
    return (dual_model, ps_dual, st_dual, data) -> begin
        Θ, trainer = data

        # Current dual predictions (the only differentiated quantity).
        dual_hat, st_dual_new = dual_model(Θ, ps_dual, st_dual)

        dual_target = @ignore_derivatives begin
            # Previous dual predictions.
            prev = trainer.prev_dual_training_state
            dual_hat_k, _ = dual_model(Θ, prev.parameters, prev.states)

            # Constraint values at the current primal prediction.
            primal = trainer.primal_training_state
            X̂, _ = trainer.primal_model(Θ, primal.parameters, primal.states)
            gh = BNK.constraints!(bm, X̂, Θ)

            # Keep the element type of the network output (avoids promoting a Float32
            # batch to Float64 through the scalar hyperparameters).
            T = eltype(dual_hat_k)
            ρ = convert(T, trainer.ρ)
            upper = convert(T, max_dual)

            # Separate bound (inequality) rows from the trailing equality rows.
            gh_bound = gh[1:end-num_equal, :]
            gh_equal = gh[end-num_equal+1:end, :]
            dual_hat_bound = dual_hat_k[1:end-num_equal, :]
            dual_hat_equal = dual_hat_k[end-num_equal+1:end, :]

            # NOTE(review): equality targets are only clipped from above, as in the
            # original code -- confirm whether a symmetric clip is intended.
            vcat(
                clamp.(dual_hat_bound .+ ρ .* gh_bound, zero(T), upper),
                min.(dual_hat_equal .+ ρ .* gh_equal, upper),
            )
        end

        loss = mean((dual_hat .- dual_target) .^ 2)
        return loss, st_dual_new, (dual_loss = loss,)
    end
end

"""
    primal_loss(method::ALMMethod)

Return a Lux-style loss function `(model, ps, st, data) -> (loss, st, stats)` where
`data = (Θ, trainer)`.

The loss is the batch-averaged augmented Lagrangian
`mean(objs) + Σ|λ ⊙ Vc| / n + ρ/2 · (ΣVc² + ΣVb²) / n`, with `Vc`/`Vb` the constraint and
variable-bound violations returned by `all_violations!` and `λ` the (non-differentiated)
current dual prediction. `stats` contains `total_loss`, `mean_violations`,
`max_violation` and `mean_objs`.
"""
function primal_loss(method::ALMMethod)
    bm = method.batch_model
    return (model, ps, st, data) -> begin
        Θ, trainer = data
        ρ = trainer.ρ
        num_s = size(Θ, 2)

        # Forward pass for prediction.
        X̂, st_new = model(Θ, ps, st)

        # Current dual predictions, held fixed.
        dual_state = trainer.dual_training_state
        dual_hat, _ = @ignore_derivatives trainer.dual_model(
            Θ, dual_state.parameters, dual_state.states,
        )

        # Objectives and violations.
        objs = BNK.objective!(bm, X̂, Θ)
        Vc, Vb = BNK.all_violations!(bm, X̂, Θ)

        # The dual model predicts one multiplier per constraint row, so the multiplier
        # term pairs with `Vc` only; bound violations enter through the penalty term.
        multiplier_term = sum(abs.(dual_hat .* Vc)) / num_s
        penalty_term = ρ / 2 * (sum(Vc .^ 2) + sum(Vb .^ 2)) / num_s
        total_loss = multiplier_term + penalty_term + mean(objs)

        V = @ignore_derivatives vcat(Vb, Vc)
        return total_loss,
        st_new,
        (
            total_loss = total_loss,
            mean_violations = mean(V),
            max_violation = maximum(V),
            mean_objs = mean(objs),
        )
    end
end

"""
    reconcile_primal(::AbstractL2OTrainer, batch_states)

Aggregate per-batch primal statistics into one named tuple: the maximum of
`max_violation` and the means of `mean_violations`, `mean_objs` and `total_loss`.
"""
function reconcile_primal(::AbstractL2OTrainer, batch_states::AbstractVector{<:NamedTuple})
    return (;
        max_violation = maximum(s.max_violation for s in batch_states),
        mean_violations = mean(s.mean_violations for s in batch_states),
        mean_objs = mean(s.mean_objs for s in batch_states),
        total_loss = mean(s.total_loss for s in batch_states),
    )
end

"""
    reconcile_dual(::AbstractL2OTrainer, batch_states)

Aggregate per-batch dual statistics into `(dual_loss = mean of batch dual losses,)`.
"""
function reconcile_dual(::AbstractL2OTrainer, batch_states::AbstractVector{<:NamedTuple})
    return (dual_loss = mean(s.dual_loss for s in batch_states),)
end

"""
    update_trainer!(method::ALMMethod, trainer::ALMTrainer,
                    primal_state::NamedTuple, dual_state::NamedTuple)

Store the aggregated statistics in `trainer`, multiply `ρ` by `α` (capped at `ρmax`) when
the new maximum violation exceeds `τ` times the previous one, and snapshot the current
dual training state as `prev_dual_training_state`.

`primal_state` must have the keys produced by [`reconcile_primal`](@ref) and `dual_state`
those produced by [`reconcile_dual`](@ref).
"""
function update_trainer!(
    method::ALMMethod,
    trainer::ALMTrainer,
    primal_state::NamedTuple,
    dual_state::NamedTuple,
)
    # Key name matches what `primal_loss` / `reconcile_primal` emit (singular).
    new_max_violation = primal_state.max_violation
    trainer.mean_violations = primal_state.mean_violations
    trainer.mean_objs = primal_state.mean_objs
    trainer.total_loss = primal_state.total_loss

    # Update ρ if the violation did not shrink enough.
    if new_max_violation > method.τ * trainer.max_violations
        trainer.ρ = min(method.ρmax, trainer.ρ * method.α)
    end
    trainer.max_violations = new_max_violation

    # Update dual state.
    trainer.dual_loss = dual_state.dual_loss
    trainer.prev_dual_training_state = deepcopy(trainer.dual_training_state)

    return nothing
end

"""
    _stopping_criterion(iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer)

Default stopping criterion: return `true` once `iter` has reached
`DEFAULT_MAX_ITERATIONS` (100).
"""
function _stopping_criterion(iter::Int, ::AbstractL2OMethod, ::AbstractL2OTrainer)
    return iter >= DEFAULT_MAX_ITERATIONS
end

"""
    primal_stopping_criterion(iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer)

Return `true` when the primal loop should stop. Defaults to `_stopping_criterion`.
"""
function primal_stopping_criterion(
    iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer,
)
    return _stopping_criterion(iter, method, trainer)
end

"""
    dual_stopping_criterion(iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer)

Return `true` when the dual loop should stop. Defaults to `_stopping_criterion`.
"""
function dual_stopping_criterion(
    iter::Int, method::AbstractL2OMethod, trainer::AbstractL2OTrainer,
)
    return _stopping_criterion(iter, method, trainer)
end

# Run one pass over `data`, returning the updated train state and the per-batch stats.
function _train_epoch(loss_fn, train_state, trainer, data)
    batch_stats = NamedTuple[]
    sizehint!(batch_stats, length(data))
    for θ in data
        # NOTE(review): Lux documents Zygote.jl as required for `AutoZygote()`; it is
        # not loaded by this module, so the caller presumably has to load it -- confirm.
        _, _, stats, train_state =
            Training.single_train_step!(AutoZygote(), loss_fn, (θ, trainer), train_state)
        push!(batch_stats, stats)
    end
    return train_state, batch_stats
end

"""
    single_train_step!(method::AbstractPrimalDualMethod,
                       trainer::AbstractPrimalDualTrainer, data)

Perform one outer primal-dual iteration.

The primal model is trained over `data` (an iterable of parameter batches with a
`length`) until `primal_stopping_criterion` returns `true`, with the dual model fixed;
then the dual model is trained until `dual_stopping_criterion` returns `true`, with the
primal model fixed. After each pass the updated `TrainState` is stored back in `trainer`.
Finally `update_trainer!` is called with the statistics of the last pass of each loop.
"""
function single_train_step!(
    method::AbstractPrimalDualMethod,
    trainer::AbstractPrimalDualTrainer,
    data,
)
    current_state_primal = (;)
    current_state_dual = (;)
    _primal_loss = primal_loss(method)
    _dual_loss = dual_loss(method)

    # Primal loop: iterate *until* the criterion says stop.
    iter_primal = 1
    while !primal_stopping_criterion(iter_primal, method, trainer)
        train_state, batch_stats =
            _train_epoch(_primal_loss, trainer.primal_training_state, trainer, data)
        # Write back so the dual loss sees the trained primal parameters.
        trainer.primal_training_state = train_state
        current_state_primal = reconcile_primal(trainer, batch_stats)
        iter_primal += 1
    end

    # Dual loop.
    iter_dual = 1
    while !dual_stopping_criterion(iter_dual, method, trainer)
        train_state, batch_stats =
            _train_epoch(_dual_loss, trainer.dual_training_state, trainer, data)
        trainer.dual_training_state = train_state
        current_state_dual = reconcile_dual(trainer, batch_stats)
        iter_dual += 1
    end

    # Update trainer state.
    update_trainer!(method, trainer, current_state_primal, current_state_dual)
    return nothing
end

end

# =============================================================================
# test/power.jl
# =============================================================================

# Resolve `filename` either as an existing path or as a case shipped with PGLib.
function _get_case_file(filename::String)
    isfile(filename) && return filename

    cached = joinpath(PGLib.PGLib_opf, filename)
    isfile(cached) || error("File $filename not found in PGLib/pglib-opf")
    return cached
end

# Parse a case file with PowerModels and return the single-network reference dict.
function _build_power_ref(filename)
    path = _get_case_file(filename)
    data = PowerModels.parse_file(path)
    PowerModels.standardize_cost_terms!(data, order = 2)
    PowerModels.calc_thermal_limits!(data)
    return PowerModels.build_ref(data)[:it][:pm][:nw][0]
end

# Move every array of a named tuple to `backend`.
_convert_array(data::N, backend) where {names,N<:NamedTuple{names}} =
    NamedTuple{names}(ExaModels.convert_array(d, backend) for d in data)

# Build the flat, index-based AC-OPF data (bus, gen, arc, branch tables and bound
# vectors) from a PowerModels reference; numeric entries are converted to `T`.
function _parse_ac_data_raw(filename; T = Float64)
    ref = _build_power_ref(filename) # FIXME: only parse once
    arcdict = Dict(a => k for (k, a) in enumerate(ref[:arcs]))
    busdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:bus]))
    gendict = Dict(k => i for (i, (k, _)) in enumerate(ref[:gen]))
    branchdict = Dict(k => i for (i, (k, _)) in enumerate(ref[:branch]))
    return (
        bus = [
            begin
                loads = [ref[:load][l] for l in ref[:bus_loads][k]]
                shunts = [ref[:shunt][s] for s in ref[:bus_shunts][k]]
                pd = T(sum(load["pd"] for load in loads; init = 0.0))
                qd = T(sum(load["qd"] for load in loads; init = 0.0))
                gs = T(sum(shunt["gs"] for shunt in shunts; init = 0.0))
                bs = T(sum(shunt["bs"] for shunt in shunts; init = 0.0))
                (i = busdict[k], pd = pd, gs = gs, qd = qd, bs = bs)
            end for (k, _) in ref[:bus]
        ],
        gen = [
            (
                i = gendict[k],
                cost1 = T(v["cost"][1]),
                cost2 = T(v["cost"][2]),
                cost3 = T(v["cost"][3]),
                bus = busdict[v["gen_bus"]],
            ) for (k, v) in ref[:gen]
        ],
        arc = [
            (i = k, rate_a = T(ref[:branch][l]["rate_a"]), bus = busdict[i]) for
            (k, (l, i, _)) in enumerate(ref[:arcs])
        ],
        branch = [
            begin
                f_idx = arcdict[i, branch["f_bus"], branch["t_bus"]]
                t_idx = arcdict[i, branch["t_bus"], branch["f_bus"]]

                g, b = PowerModels.calc_branch_y(branch)
                tr, ti = PowerModels.calc_branch_t(branch)
                ttm = tr^2 + ti^2

                g_fr = branch["g_fr"]
                b_fr = branch["b_fr"]
                g_to = branch["g_to"]
                b_to = branch["b_to"]

                (
                    i = branchdict[i],
                    j = 1,
                    f_idx = f_idx,
                    t_idx = t_idx,
                    f_bus = busdict[branch["f_bus"]],
                    t_bus = busdict[branch["t_bus"]],
                    c1 = T((-g * tr - b * ti) / ttm),
                    c2 = T((-b * tr + g * ti) / ttm),
                    c3 = T((-g * tr + b * ti) / ttm),
                    c4 = T((-b * tr - g * ti) / ttm),
                    c5 = T((g + g_fr) / ttm),
                    c6 = T((b + b_fr) / ttm),
                    c7 = T((g + g_to)),
                    c8 = T((b + b_to)),
                    rate_a_sq = T(branch["rate_a"]^2),
                )
            end for (i, branch) in ref[:branch]
        ],
        ref_buses = [busdict[i] for (i, _) in ref[:ref_buses]],
        vmax = [T(v["vmax"]) for (_, v) in ref[:bus]],
        vmin = [T(v["vmin"]) for (_, v) in ref[:bus]],
        pmax = [T(v["pmax"]) for (_, v) in ref[:gen]],
        pmin = [T(v["pmin"]) for (_, v) in ref[:gen]],
        qmax = [T(v["qmax"]) for (_, v) in ref[:gen]],
        qmin = [T(v["qmin"]) for (_, v) in ref[:gen]],
        rate_a = [T(ref[:branch][l]["rate_a"]) for (l, _, _) in ref[:arcs]],
        angmax = [T(b["angmax"]) for (_, b) in ref[:branch]],
        angmin = [T(b["angmin"]) for (_, b) in ref[:branch]],
    )
end

_parse_ac_data(filename) = _parse_ac_data_raw(filename)
function _parse_ac_data(filename, backend; T = Float64)
    return _convert_array(_parse_ac_data_raw(filename, T = T), backend)
end

# Parametric version

# Build an AC-OPF ExaModel whose bus loads are scaled by one parameter per bus.
# Returns `(model, n_bus, n_gen, n_arc)`.
# `backend = nothing` is the default so the function is callable without any GPU
# package loaded (the previous default referenced a name that is never imported here).
function create_parametric_ac_power_model(
    filename::String = "pglib_opf_case14_ieee.m";
    prod::Bool = true,
    backend = nothing,
    T = Float64,
    kwargs...,
)
    data = _parse_ac_data(filename, backend, T = T)
    c = ExaCore(T; backend = backend)

    va = variable(c, length(data.bus);)
    vm = variable(
        c,
        length(data.bus);
        start = fill!(similar(data.bus, T), 1.0),
        lvar = data.vmin,
        uvar = data.vmax,
    )

    pg = variable(c, length(data.gen); lvar = data.pmin, uvar = data.pmax)
    qg = variable(c, length(data.gen); lvar = data.qmin, uvar = data.qmax)

    @allowscalar load_multiplier = parameter(c, [1.0 for b in data.bus])

    p = variable(c, length(data.arc); lvar = -data.rate_a, uvar = data.rate_a)
    q = variable(c, length(data.arc); lvar = -data.rate_a, uvar = data.rate_a)

    objective(c, g.cost1 * pg[g.i]^2 + g.cost2 * pg[g.i] + g.cost3 for g in data.gen)

    # Reference bus angle ------------------------------------------------------
    constraint(c, va[i] for i in data.ref_buses)

    # Branch power-flow equations ---------------------------------------------
    constraint(
        c,
        (
            p[b.f_idx] - b.c5 * vm[b.f_bus]^2 -
            b.c3 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) -
            b.c4 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for
            b in data.branch
        ),
    )

    constraint(
        c,
        (
            q[b.f_idx] +
            b.c6 * vm[b.f_bus]^2 +
            b.c4 * (vm[b.f_bus] * vm[b.t_bus] * cos(va[b.f_bus] - va[b.t_bus])) -
            b.c3 * (vm[b.f_bus] * vm[b.t_bus] * sin(va[b.f_bus] - va[b.t_bus])) for
            b in data.branch
        ),
    )

    constraint(
        c,
        (
            p[b.t_idx] - b.c7 * vm[b.t_bus]^2 -
            b.c1 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) -
            b.c2 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for
            b in data.branch
        ),
    )

    constraint(
        c,
        (
            q[b.t_idx] +
            b.c8 * vm[b.t_bus]^2 +
            b.c2 * (vm[b.t_bus] * vm[b.f_bus] * cos(va[b.t_bus] - va[b.f_bus])) -
            b.c1 * (vm[b.t_bus] * vm[b.f_bus] * sin(va[b.t_bus] - va[b.f_bus])) for
            b in data.branch
        ),
    )

    # Angle difference limits --------------------------------------------------
    constraint(
        c,
        va[b.f_bus] - va[b.t_bus] for b in data.branch;
        lcon = data.angmin,
        ucon = data.angmax,
    )

    # Apparent power thermal limits -------------------------------------------
    # Lower bounds use the model precision `T` rather than a hard-coded Float64.
    constraint(
        c,
        p[b.f_idx]^2 + q[b.f_idx]^2 - b.rate_a_sq for b in data.branch;
        lcon = fill!(similar(data.branch, T, length(data.branch)), -Inf),
    )
    constraint(
        c,
        p[b.t_idx]^2 + q[b.t_idx]^2 - b.rate_a_sq for b in data.branch;
        lcon = fill!(similar(data.branch, T, length(data.branch)), -Inf),
    )

    # Power balance at each bus -----------------------------------------------
    load_balance_p =
        constraint(c, b.pd * load_multiplier[b.i] + b.gs * vm[b.i]^2 for b in data.bus)
    load_balance_q =
        constraint(c, b.qd * load_multiplier[b.i] - b.bs * vm[b.i]^2 for b in data.bus)

    # Map arc & generator variables into the bus balance equations
    constraint!(c, load_balance_p, a.bus => p[a.i] for a in data.arc)
    constraint!(c, load_balance_q, a.bus => q[a.i] for a in data.arc)
    constraint!(c, load_balance_p, g.bus => -pg[g.i] for g in data.gen)
    constraint!(c, load_balance_q, g.bus => -qg[g.i] for g in data.gen)

    return ExaModel(c; prod = prod), length(data.bus), length(data.gen), length(data.arc)
end

# =============================================================================
# test/runtests.jl
# =============================================================================
using L2OALM
using Test
using Lux
using Optimisers
using MLUtils
using KernelAbstractions
using ExaModels
using BatchNLPKernels
using CUDA

using PowerModels
PowerModels.silence()
using PGLib

using Random
using Statistics

import GPUArraysCore: @allowscalar

const BNK = BatchNLPKernels

include("power.jl")

@testset "L2OALM.jl" begin
    # Build a Lux `Chain` of `Dense` layers: `num_p` inputs, the given hidden widths
    # (each followed by `activation`), and a linear output layer of width `num_y`.
    function feed_forward_builder(
        num_p::Integer,
        num_y::Integer,
        hidden_layers::AbstractVector{<:Integer};
        activation = relu,
    )
        layer_sizes = [num_p; hidden_layers; num_y]
        n_layers = length(layer_sizes) - 1
        dense_layers = map(1:n_layers) do i
            if i < n_layers
                Dense(layer_sizes[i], layer_sizes[i+1], activation)
            else
                # Final layer with no activation.
                Dense(layer_sizes[i], layer_sizes[i+1])
            end
        end
        return Chain(dense_layers...)
    end

    function test_alm_training(;
        filename = "pglib_opf_case14_ieee.m",
        dev_gpu = gpu_device(),
        backend = CPU(),
        batch_size = 32,
        dataset_size = 3200,
        rng = Random.default_rng(),
        T = Float64,
        num_outer_iters = 100,
    )
        model, nbus, ngen, blines =
            create_parametric_ac_power_model(filename; backend = backend, T = T)
        bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full))
        bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full))

        nvar = model.meta.nvar
        ncon = model.meta.ncon
        nθ = length(model.θ)
        num_equal = nbus * 2

        # Draw from the supplied RNG so the run is reproducible.
        Θ_train = randn(rng, T, nθ, dataset_size) |> dev_gpu
        Θ_test = randn(rng, T, nθ, dataset_size) |> dev_gpu

        primal_model = feed_forward_builder(nθ, nvar, [320, 320])
        ps_primal, st_primal = Lux.setup(rng, primal_model)
        ps_primal = ps_primal |> dev_gpu
        st_primal = st_primal |> dev_gpu

        dual_model = feed_forward_builder(nθ, ncon, [320, 320])
        ps_dual, st_dual = Lux.setup(rng, dual_model)
        ps_dual = ps_dual |> dev_gpu
        st_dual = st_dual |> dev_gpu

        X̂, _ = primal_model(Θ_test, ps_primal, st_primal)

        y = BNK.objective!(bm_test, X̂, Θ_test)

        @test length(y) == dataset_size
        Vc, Vb = BNK.all_violations!(bm_test, X̂, Θ_test)
        @test size(Vc) == (ncon, dataset_size)
        @test size(Vb) == (nvar, dataset_size)

        train_state_primal = Lux.Training.TrainState(
            primal_model, ps_primal, st_primal, Optimisers.Adam(1e-5),
        )
        train_state_dual =
            Lux.Training.TrainState(dual_model, ps_dual, st_dual, Optimisers.Adam(1e-5))

        data = DataLoader((Θ_train); batchsize = batch_size, shuffle = true) .|> dev_gpu

        method = ALMMethod(; batch_model = bm_train, num_equal = num_equal)
        trainer = ALMTrainer(primal_model, train_state_primal, dual_model, train_state_dual)
        method_test = ALMMethod(; batch_model = bm_test, num_equal = num_equal)
        test_primal_loss = primal_loss(method_test)
        test_dual_loss = dual_loss(method_test)

        # Evaluate a loss closure on the test set with the parameters held by `state`.
        # The closures return `(loss, new_state, stats)`.
        evaluate(loss_fn, net, state) =
            loss_fn(net, state.parameters, state.states, (Θ_test, trainer))

        prev_primal_loss_val, _, _ =
            evaluate(test_primal_loss, primal_model, trainer.primal_training_state)

        # Declared before the loop so the final value is visible after it.
        primal_loss_val = prev_primal_loss_val
        for iter = 1:num_outer_iters
            single_train_step!(method, trainer, data)
            # Log, using the *trained* parameters stored in the trainer.
            primal_loss_val, _, stats_primal =
                evaluate(test_primal_loss, primal_model, trainer.primal_training_state)
            dual_loss_val, _, _ =
                evaluate(test_dual_loss, dual_model, trainer.dual_training_state)

            @info "Validation Testset: Iteration $iter" primal_loss_val stats_primal.max_violation stats_primal.mean_violations stats_primal.mean_objs dual_loss_val
        end
        # Check that the primal loss is decreasing
        @test primal_loss_val < prev_primal_loss_val

        return
    end

    @testset "Penalty Training" begin
        backend, dev = if haskey(ENV, "BNK_TEST_CUDA")
            CUDABackend(), gpu_device()
        else
            CPU(), cpu_device()
        end

        test_alm_training(;
            filename = "pglib_opf_case14_ieee.m",
            dev_gpu = dev,
            backend = backend,
            T = Float32,
        )
    end
end