diff --git a/src/ADMM.jl b/src/ADMM.jl index 1de372d..3aa8b0b 100644 --- a/src/ADMM.jl +++ b/src/ADMM.jl @@ -25,7 +25,7 @@ mutable struct ADMMState{rT <: Real, rvecT <: AbstractVector{rT}, vecT <: Union{ z::Vector{vecT} zᵒˡᵈ::Vector{vecT} u::Vector{vecT} - uᵒˡᵈ::Vector{vecT} + uᵒˡᵈ::Vector{vecT} # other paremters ρ::rvecT iteration::Int64 @@ -156,7 +156,7 @@ function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwarg state = ADMMState(β, β_y, x, xᵒˡᵈ, z, zᵒˡᵈ, u, uᵒˡᵈ, state.ρ, state.iteration, cgStateVars, state.rᵏ, state.sᵏ, state.ɛᵖʳⁱ, state.ɛᵈᵘᵃ, state.σᵃᵇˢ, state.Δ, state.absTol, state.relTol, state.tolInner) - + solver.state = state init!(solver, state, b; kwargs...) end @@ -244,16 +244,37 @@ function iterate(solver::ADMM, state::ADMMState) state.u[i] .-= state.z[i] # update convergence criteria (one for each constraint) - state.rᵏ[i] = norm(solver.regTrafo[i] * state.x - state.z[i]) # primal residual (x-z) - state.sᵏ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * (state.z[i] .- state.zᵒˡᵈ[i])) # dual residual (concerning f(x)) + # The following commented lines are a readable calculation of the convergence criteria. However, they allocate a substantial amount of memory, and we use a less readable, but less allocating code that hijacks the variables xᵒˡᵈ and zᵒˡᵈ as they are unused at this stage of the iteration. + # state.rᵏ[i] = norm(solver.regTrafo[i] * state.x - state.z[i]) # primal residual (x-z) + # state.sᵏ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * (state.z[i] .- state.zᵒˡᵈ[i])) # dual residual (concerning f(x)) + + # state.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * state.x), norm(state.z[i])) + # state.ɛᵈᵘᵃ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * state.u[i]) - state.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * state.x), norm(state.z[i])) - state.ɛᵈᵘᵃ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * state.u[i]) + # Δᵒˡᵈ = state.Δ[i] + # state.Δ[i] = norm(state.x .- state.xᵒˡᵈ ) + + # norm(state.z[i] .- state.zᵒˡᵈ[i]) + + # norm(state.u[i] .- state.uᵒˡᵈ[i]) + + state.xᵒˡᵈ .= state.x .- state.xᵒˡᵈ + state.zᵒˡᵈ[i] .= state.z[i] .- state.zᵒˡᵈ[i] + state.uᵒˡᵈ[i] .= state.u[i] .- state.uᵒˡᵈ[i] Δᵒˡᵈ = state.Δ[i] - state.Δ[i] = norm(state.x .- state.xᵒˡᵈ ) + - norm(state.z[i] .- state.zᵒˡᵈ[i]) + - norm(state.u[i] .- state.uᵒˡᵈ[i]) + state.Δ[i] = norm(state.xᵒˡᵈ) + norm(state.zᵒˡᵈ[i]) + norm(state.uᵒˡᵈ[i]) + + mul!(state.xᵒˡᵈ, adjoint(solver.regTrafo[i]), state.zᵒˡᵈ[i]) + state.sᵏ[i] = state.ρ[i] * norm(state.xᵒˡᵈ) # dual residual (concerning f(x)) + + mul!(state.zᵒˡᵈ[i], solver.regTrafo[i], state.x) + state.ɛᵖʳⁱ[i] = max(norm(state.zᵒˡᵈ[i]), norm(state.z[i])) + + state.zᵒˡᵈ[i] .-= state.z[i] + state.rᵏ[i] = norm(state.zᵒˡᵈ[i]) # primal residual (x-z) + + mul!(state.xᵒˡᵈ, adjoint(solver.regTrafo[i]), state.u[i]) + state.ɛᵈᵘᵃ[i] = state.ρ[i] * norm(state.xᵒˡᵈ) + if (solver.vary_ρ == :balance && state.rᵏ[i]/state.ɛᵖʳⁱ[i] > 10state.sᵏ[i]/state.ɛᵈᵘᵃ[i]) || # adapt ρ according to Boyd et al. (solver.vary_ρ == :PnP && state.Δ[i]/Δᵒˡᵈ > 0.9) # adapt ρ according to Chang et al.