Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mutable struct ADMMState{rT <: Real, rvecT <: AbstractVector{rT}, vecT <: Union{
z::Vector{vecT}
zᵒˡᵈ::Vector{vecT}
u::Vector{vecT}
uᵒˡᵈ::Vector{vecT}
uᵒˡᵈ::Vector{vecT}
# other paremters
ρ::rvecT
iteration::Int64
Expand Down Expand Up @@ -156,7 +156,7 @@ function init!(solver::ADMM, state::ADMMState{rT, rvecT, vecT}, b::otherT; kwarg

state = ADMMState(β, β_y, x, xᵒˡᵈ, z, zᵒˡᵈ, u, uᵒˡᵈ, state.ρ, state.iteration, cgStateVars,
state.rᵏ, state.sᵏ, state.ɛᵖʳⁱ, state.ɛᵈᵘᵃ, state.σᵃᵇˢ, state.Δ, state.absTol, state.relTol, state.tolInner)

solver.state = state
init!(solver, state, b; kwargs...)
end
Expand Down Expand Up @@ -244,16 +244,37 @@ function iterate(solver::ADMM, state::ADMMState)
state.u[i] .-= state.z[i]

# update convergence criteria (one for each constraint)
state.rᵏ[i] = norm(solver.regTrafo[i] * state.x - state.z[i]) # primal residual (x-z)
state.sᵏ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * (state.z[i] .- state.zᵒˡᵈ[i])) # dual residual (concerning f(x))
# The following commented lines are a readable calculation of the convergence criteria. However, they allocate a substantial amount of memory, and we use a less readable, but less allocating code that hijacks the variables xᵒˡᵈ and zᵒˡᵈ as they are unused at this stage of the iteration.
# state.rᵏ[i] = norm(solver.regTrafo[i] * state.x - state.z[i]) # primal residual (x-z)
# state.sᵏ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * (state.z[i] .- state.zᵒˡᵈ[i])) # dual residual (concerning f(x))

# state.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * state.x), norm(state.z[i]))
# state.ɛᵈᵘᵃ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * state.u[i])

state.ɛᵖʳⁱ[i] = max(norm(solver.regTrafo[i] * state.x), norm(state.z[i]))
state.ɛᵈᵘᵃ[i] = norm(state.ρ[i] * adjoint(solver.regTrafo[i]) * state.u[i])
# Δᵒˡᵈ = state.Δ[i]
# state.Δ[i] = norm(state.x .- state.xᵒˡᵈ ) +
# norm(state.z[i] .- state.zᵒˡᵈ[i]) +
# norm(state.u[i] .- state.uᵒˡᵈ[i])

state.xᵒˡᵈ .= state.x .- state.xᵒˡᵈ
state.zᵒˡᵈ[i] .= state.z[i] .- state.zᵒˡᵈ[i]
state.uᵒˡᵈ[i] .= state.u[i] .- state.uᵒˡᵈ[i]

Δᵒˡᵈ = state.Δ[i]
state.Δ[i] = norm(state.x .- state.xᵒˡᵈ ) +
norm(state.z[i] .- state.zᵒˡᵈ[i]) +
norm(state.u[i] .- state.uᵒˡᵈ[i])
state.Δ[i] = norm(state.xᵒˡᵈ) + norm(state.zᵒˡᵈ[i]) + norm(state.uᵒˡᵈ[i])

mul!(state.xᵒˡᵈ, adjoint(solver.regTrafo[i]), state.zᵒˡᵈ[i])
state.sᵏ[i] = state.ρ[i] * norm(state.xᵒˡᵈ) # dual residual (concerning f(x))

mul!(state.zᵒˡᵈ[i], solver.regTrafo[i], state.x)
state.ɛᵖʳⁱ[i] = max(norm(state.zᵒˡᵈ[i]), norm(state.z[i]))

state.zᵒˡᵈ[i] .-= state.z[i]
state.rᵏ[i] = norm(state.zᵒˡᵈ[i]) # primal residual (x-z)

mul!(state.xᵒˡᵈ, adjoint(solver.regTrafo[i]), state.u[i])
state.ɛᵈᵘᵃ[i] = state.ρ[i] * norm(state.xᵒˡᵈ)


if (solver.vary_ρ == :balance && state.rᵏ[i]/state.ɛᵖʳⁱ[i] > 10state.sᵏ[i]/state.ɛᵈᵘᵃ[i]) || # adapt ρ according to Boyd et al.
(solver.vary_ρ == :PnP && state.Δ[i]/Δᵒˡᵈ > 0.9) # adapt ρ according to Chang et al.
Expand Down
Loading