
Commit 4efbdf6

Batch utils (#12)
* remove bridge test
* batch/gpu utils
* use batch/gpu utils
* test names
* batch tests
* flatten_y
* fix types
* latest julia for tests
1 parent b0b5328 commit 4efbdf6

File tree

9 files changed (+275, -50 lines)

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.10'
+          - '1'
         os:
           - ubuntu-latest
         arch:

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@ authors = ["Klamkin", "Michael <[email protected]> and contributors"]
 version = "1.0.0-DEV"
 
 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Dualization = "191a621a-6537-11e9-281d-650236a99e60"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -14,6 +15,7 @@ MathOptSetDistances = "3b969827-a86c-476c-9527-bb6f1a8fbad5"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [compat]
+Adapt = "4.3.0"
 Dualization = "=0.7.0"
 MathOptSetDistances = "=0.2.11"
 

src/L2ODLL.jl

Lines changed: 13 additions & 5 deletions
@@ -1,5 +1,6 @@
 module L2ODLL
 
+import Adapt
 import Dualization
 import JuMP
 import LinearAlgebra
@@ -17,6 +18,7 @@ const ADTypes = DI.ADTypes
 
 abstract type AbstractDecomposition end # must have p_ref and y_ref and implement can_decompose
 
+include("batch.jl")
 include("layers/generic.jl")
 include("layers/bounded.jl")
 include("layers/convex_qp.jl")
@@ -178,21 +180,27 @@ function y_shape(cache::DLLCache)
 end
 
 """
-    flatten_y(y::AbstractVector)
+    flatten_y(y)
 
 Flatten a vector of `y` variables into a single vector, i.e. Vector{Vector{Float64}} -> Vector{Float64}.
 """
-function flatten_y(y::AbstractVector)
+function flatten_y(y)
     return reduce(vcat, y)
 end
 
 """
-    unflatten_y(y::AbstractVector, y_shape::AbstractVector{Int})
+    unflatten_y(y::Vector{T}, y_shape::Vector{Int}) where T
 
 Unflatten a vector of flattened `y` variables into a vector of vectors, i.e. Vector{Float64} -> Vector{Vector{Float64}}.
 """
-function unflatten_y(y::AbstractVector, y_shape::AbstractVector{Int})
-    return [y[start_idx:start_idx + shape - 1] for (start_idx, shape) in enumerate(y_shape)]
+function unflatten_y(y::Vector{T}, y_shape::Vector{Int}) where T
+    result = Vector{T}[]
+    start_idx = 1
+    for shape in y_shape
+        push!(result, y[start_idx:start_idx + shape - 1])
+        start_idx += shape
+    end
+    return result
 end
 
 end # module
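
The rewritten unflatten_y also fixes a real bug: the old comprehension took start_idx from enumerate, so it stepped 1, 2, 3, ... instead of advancing by each block's length. A standalone round-trip sketch (function bodies copied from this diff; the sample values are illustrative):

    flatten_y(y) = reduce(vcat, y)

    function unflatten_y(y::Vector{T}, y_shape::Vector{Int}) where T
        result = Vector{T}[]
        start_idx = 1
        for shape in y_shape
            push!(result, y[start_idx:start_idx + shape - 1])
            start_idx += shape  # advance by block length, not by 1
        end
        return result
    end

    y = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
    y_flat = flatten_y(y)                        # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    @assert unflatten_y(y_flat, length.(y)) == y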

src/batch.jl

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+abstract type AbstractExprMatrix end
+
+struct AffExprMatrix{V,T} <: AbstractExprMatrix
+    c::V
+    c0::T
+end
+struct QuadExprMatrix{M,V,T} <: AbstractExprMatrix
+    Q::M
+    c::V
+    c0::T
+end
+struct VecAffExprMatrix{V,T} <: AbstractExprMatrix
+    A::V
+    b::T
+end
+
+
+Adapt.@adapt_structure AffExprMatrix
+Adapt.@adapt_structure QuadExprMatrix
+Adapt.@adapt_structure VecAffExprMatrix
+
+(aem::AffExprMatrix)(x) = aem.c'*x + aem.c0
+(qem::QuadExprMatrix)(x) = x'*qem.Q*x + qem.c'*x + qem.c0
+(vaem::VecAffExprMatrix)(x) = vaem.A*x + vaem.b
+
+
+function AffExprMatrix(
+    aff::JuMP.GenericAffExpr{T,V},
+    v::Vector{V};
+    backend = nothing
+) where {T,V}
+    n = length(v)
+    vr_to_idx = _vr_to_idx(v)
+
+    c = zeros(T, n)
+    c0 = aff.constant
+
+    for (coeff, vr) in JuMP.linear_terms(aff)
+        c[vr_to_idx[vr]] = coeff
+    end
+
+    c = _backend_vector(backend)(c)
+    return AffExprMatrix(c, c0)
+end
+
+
+function QuadExprMatrix(
+    qexpr::JuMP.GenericQuadExpr{T,V},
+    v::Vector{V};
+    backend = nothing
+) where {T,V}
+    quad_terms = JuMP.quad_terms(qexpr)
+    nq = length(quad_terms)
+    n = length(v)
+
+    vr_to_idx = _vr_to_idx(v)
+
+    Qi = Int[]
+    sizehint!(Qi, nq)
+    Qj = Int[]
+    sizehint!(Qj, nq)
+    Qv = T[]
+    sizehint!(Qv, nq)
+    c = zeros(T, n)
+    c0 = qexpr.aff.constant
+
+    for (coeff, vr1, vr2) in quad_terms
+        push!(Qi, vr_to_idx[vr1])
+        push!(Qj, vr_to_idx[vr2])
+        push!(Qv, coeff)
+    end
+    for (coeff, vr) in JuMP.linear_terms(qexpr)
+        c[vr_to_idx[vr]] = coeff
+    end
+
+    Q = _backend_matrix(backend)(Qi, Qj, Qv, n, n)
+    c = _backend_vector(backend)(c)
+    return QuadExprMatrix(Q, c, c0)
+end
+
+
+function VecAffExprMatrix(
+    vaff::Vector{JuMP.GenericAffExpr{T,V}},
+    v::Vector{V};
+    backend = nothing
+) where {T,V}
+    m = length(vaff)
+    n = length(v)
+
+    linear_terms = [JuMP.linear_terms(jaff) for jaff in vaff]
+    nlinear = sum(length.(linear_terms))
+
+    vr_to_idx = _vr_to_idx(v)
+
+    Ai = Int[]
+    sizehint!(Ai, nlinear)
+    Aj = Int[]
+    sizehint!(Aj, nlinear)
+    Av = T[]
+    sizehint!(Av, nlinear)
+    b = zeros(T, m)
+
+    for (i, jaff) in enumerate(vaff)
+        for (coeff, vr) in linear_terms[i]
+            if vr ∈ v
+                push!(Ai, i)
+                push!(Aj, vr_to_idx[vr])
+                push!(Av, coeff)
+            else
+                error("Variable $vr from function $i not found")
+            end
+        end
+        b[i] = jaff.constant
+    end
+
+    A = _backend_matrix(backend)(Ai, Aj, Av, m, n)
+    b = _backend_vector(backend)(b)
+    return VecAffExprMatrix(A, b)
+end
+
+function _vr_to_idx(v::V) where {V}
+    vr_to_idx = Dict{eltype(V), Int}()
+    for (i, vr) in enumerate(v)
+        vr_to_idx[vr] = i
+    end
+    return vr_to_idx
+end
+
+function _backend_matrix(::Nothing)
+    return SparseArrays.sparse
+end
+
+function _backend_vector(::Nothing)
+    return Vector
+end
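
These wrappers pull the coefficients out of JuMP expressions once, so each later evaluation is plain linear algebra on arrays, and via Adapt.@adapt_structure those arrays can be moved to a GPU. A usage sketch, assuming the package is loaded; the tiny model and input values are made up for illustration:

    import JuMP

    model = JuMP.Model()
    JuMP.@variable(model, x[1:2])
    q = x[1]^2 + 2.0*x[1]*x[2] + 3.0*x[2] + 1.0   # GenericQuadExpr

    # Capture sparse Q, dense c, and scalar c0 once (default backend = nothing).
    qem = L2ODLL.QuadExprMatrix(q, collect(x))
    qem([1.0, 2.0])   # x'Qx + c'x + c0 = 5.0 + 6.0 + 1.0 = 12.0

    # With CUDA.jl loaded, the captured arrays could be moved to the device,
    # e.g. Adapt.adapt(CUDA.CuArray, qem). (Hypothetical usage, not from this commit.)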

src/layers/bounded.jl

Lines changed: 21 additions & 16 deletions
@@ -27,7 +27,9 @@ function can_decompose(model::JuMP.Model, ::Type{BoundDecomposition})
     return false
 end
 
-function bounded_builder(decomposition::BoundDecomposition, proj_fn, dual_model::JuMP.Model; completion=:exact, μ=1.0)
+function bounded_builder(decomposition::BoundDecomposition, proj_fn, dual_model::JuMP.Model;
+    completion=:exact, μ=1.0, backend=nothing
+)
     p_vars = get_p(dual_model, decomposition)
     y_vars = get_y_dual(dual_model, decomposition)
     zl_vars = only.(get_zl(dual_model, decomposition))
@@ -76,22 +78,24 @@ function bounded_builder(decomposition::BoundDecomposition, proj_fn, dual_model:
         error("Invalid completion type: $completion. Must be :exact or :log.")
     end
 
+    z_fn = VecAffExprMatrix(
+        zl_plus_zu,
+        [flatten_y(y_vars); p_vars];
+        backend=backend
+    )
+    obj_fn = QuadExprMatrix(
+        obj_func,
+        [flatten_y(y_vars); p_vars; zl_vars; zu_vars];
+        backend=backend
+    )
     return (y_pred, param_value) -> begin
         y_pred_proj = proj_fn(y_pred)
 
-        zl_plus_zu_val = JuMP.value.(vr -> _find_and_return_value(vr,
-            [reduce(vcat, y_vars), p_vars],
-            [reduce(vcat, y_pred_proj), param_value]),
-            zl_plus_zu
-        )
+        zl_plus_zu_val = z_fn([flatten_y(y_pred_proj); param_value])
 
         zl, zu = complete_zlzu(completer, zl_plus_zu_val)
 
-        JuMP.value.(vr -> _find_and_return_value(vr,
-            [reduce(vcat, y_vars), p_vars, zl_vars, zu_vars],
-            [reduce(vcat, y_pred_proj), param_value, zl, zu]),
-            obj_func
-        )
+        obj_fn([flatten_y(y_pred_proj); param_value; zl; zu])
     end
 end
 
@@ -109,14 +113,15 @@ function complete_zlzu(::ExactBoundedCompletion, zl_plus_zu)
     return max.(zl_plus_zu, zero(eltype(zl_plus_zu))), -max.(-zl_plus_zu, zero(eltype(zl_plus_zu)))
 end
 
-struct LogBoundedCompletion{T<:Real} <: BoundedCompletion
+struct LogBoundedCompletion{V,T} <: BoundedCompletion
     μ::T
-    l::AbstractVector{T}
-    u::AbstractVector{T}
+    l::V
+    u::V
 end
-function complete_zlzu(c::LogBoundedCompletion, zl_plus_zu)
+Adapt.@adapt_structure LogBoundedCompletion
+function complete_zlzu(c::LogBoundedCompletion{V,T}, zl_plus_zu) where {V,T}
     v = c.μ ./ (c.u - c.l)
-    w = eltype(zl_plus_zu)(1//2) .* zl_plus_zu
+    w = T(1//2) .* zl_plus_zu
     sqrtv2w2 = hypot.(v, w)
     return (
         v + w + sqrtv2w2,
src/layers/convex_qp.jl

Lines changed: 18 additions & 13 deletions
@@ -33,7 +33,9 @@ function can_decompose(model::JuMP.Model, ::Type{ConvexQP})
     return true
 end
 
-function convex_qp_builder(decomposition::ConvexQP, proj_fn, dual_model::JuMP.Model)
+function convex_qp_builder(decomposition::ConvexQP, proj_fn, dual_model::JuMP.Model;
+    backend=nothing
+)
     p_vars = get_p(dual_model, decomposition)
     y_vars = get_y_dual(dual_model, decomposition)
     x_vars = get_x(decomposition)
@@ -73,25 +75,28 @@ function convex_qp_builder(decomposition::ConvexQP, proj_fn, dual_model::JuMP.Mo
     end
     @assert isempty(idx_left) "Some z were not found in the model"
 
-    obj_func = JuMP.objective_function(dual_model)
-
     Qinv = inv(Q)
+
+    Qz_fn = VecAffExprMatrix(
+        Qz,
+        [flatten_y(y_vars); p_vars];
+        backend=backend
+    )
+
+    obj_fn = QuadExprMatrix(
+        JuMP.objective_function(dual_model),
+        [flatten_y(y_vars); p_vars; z_vars];
+        backend=backend
+    )
+
     return (y_pred, param_value) -> begin
         y_pred_proj = proj_fn(y_pred)
 
-        Qz_val = JuMP.value.(vr -> _find_and_return_value(vr,
-            [reduce(vcat, y_vars), p_vars],
-            [reduce(vcat, y_pred_proj), param_value]),
-            Qz
-        )
+        Qz_val = Qz_fn([flatten_y(y_pred_proj); param_value])
 
         z = Qinv * Qz_val
 
-        JuMP.value.(vr -> _find_and_return_value(vr,
-            [reduce(vcat, y_vars), p_vars, z_vars],
-            [reduce(vcat, y_pred_proj), param_value, z]),
-            obj_func
-        )
+        obj_fn([flatten_y(y_pred_proj); param_value; z])
    end
 end
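
In this layer the inner variables z are completed in closed form: the affine map Qz is evaluated at the predicted [y; p] and multiplied by the precomputed inv(Q), after which the fixed quadratic obj_fn scores [y; p; z]. A sketch of just the recovery step (Q and Qz_val are made-up stand-ins for the builder's values):

    using LinearAlgebra

    Q = [2.0 1.0; 1.0 3.0]   # quadratic block for z, assumed invertible
    Qinv = inv(Q)            # computed once by the builder
    Qz_val = [1.0, 2.0]      # stand-in for Qz_fn([flatten_y(y_pred_proj); param_value])
    z = Qinv * Qz_val        # closed-form completion: [0.2, 0.6]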

src/layers/generic.jl

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ function jump_builder(decomposition::AbstractDecomposition, proj_fn::Function, d
     lock(completion_model.ext[:🔒])
     try
         JuMP.set_parameter_value.(p_ref, param_value)
-        JuMP.set_parameter_value.(reduce(vcat, y_ref), reduce(vcat, proj_fn(y_pred)))
+        JuMP.set_parameter_value.(flatten_y(y_ref), flatten_y(proj_fn(y_pred)))
 
         JuMP.optimize!(completion_model)
         JuMP.assert_is_solved_and_feasible(completion_model)
@@ -63,7 +63,7 @@ function _make_completion_model(decomposition::AbstractDecomposition, dual_model
     # mark y as parameters (optimizing over z only)
     p_ref = getindex.(ref_map, get_p(dual_model, decomposition))
     y_ref = getindex.(ref_map, get_y_dual(dual_model, decomposition))
-    y_ref_flat = reduce(vcat, y_ref)
+    y_ref_flat = flatten_y(y_ref)
     JuMP.@constraint(completion_model, y_ref_flat .∈ MOI.Parameter.(zeros(length(y_ref_flat))))
 
     return completion_model, (p_ref, y_ref, ref_map)

src/projection.jl

Lines changed: 1 addition & 1 deletion
@@ -31,4 +31,4 @@ function get_y_sets(dual_model, decomposition)
         isnothing(set) ? nothing : MOI.get(dual_model, MOI.ConstraintSet(), set)
         for set in get_y_constraint(dual_model, decomposition)
     ]
-end
+end
