5
5
# cvxpylayers
6
6
7
7
cvxpylayers is a Python library for constructing differentiable convex
8
- optimization layers in PyTorch and TensorFlow using CVXPY.
8
+ optimization layers in PyTorch, JAX, and TensorFlow using CVXPY.
9
9
A convex optimization layer solves a parametrized convex optimization problem
10
10
in the forward pass to produce a solution.
11
11
It computes the derivative of the solution with respect to
@@ -36,28 +36,32 @@ cvxpylayers.
36
36
pip install cvxpylayers
37
37
```
38
38
39
- Our package includes convex optimization layers for PyTorch and TensorFlow 2.0;
39
+ Our package includes convex optimization layers for
40
+ PyTorch, JAX, and TensorFlow 2.0;
40
41
the layers are functionally equivalent. You will need to install
41
- [ PyTorch] ( https://pytorch.org ) or [ TensorFlow] ( https://www.tensorflow.org )
42
+ [ PyTorch] ( https://pytorch.org ) ,
43
+ [ JAX] ( https://github.com/google/jax ) , or
44
+ [ TensorFlow] ( https://www.tensorflow.org )
42
45
separately, which can be done by following the instructions on their websites.
43
46
44
47
cvxpylayers has the following dependencies:
45
48
* Python 3
46
49
* [ NumPy] ( https://pypi.org/project/numpy/ )
47
50
* [ CVXPY] ( https://github.com/cvxgrp/cvxpy ) >= 1.1.a4
48
- * [ TensorFlow ] ( https://tensorflow .org ) >= 2.0 or [ PyTorch ] ( https://pytorch .org ) >= 1 .0
51
+ * [ PyTorch ] ( https://pytorch .org ) >= 1.0, [ JAX ] ( https://github.com/google/jax ) >= 0.2.12, or [ TensorFlow ] ( https://tensorflow .org ) >= 2 .0
49
52
* [ diffcp] ( https://github.com/cvxgrp/diffcp ) >= 1.0.13
50
53
51
54
## Usage
52
- Below are usage examples of our PyTorch and TensorFlow layers. Note that
53
- the parametrized convex optimization problems must be constructed in CVXPY,
54
- using [ DPP] ( https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming ) .
55
+ Below are usage examples of our PyTorch, JAX, and TensorFlow layers.
56
+ Note that the parametrized convex optimization problems must be constructed
57
+ in CVXPY, using
58
+ [ DPP] ( https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming ) .
55
59
56
60
### PyTorch
57
61
58
62
``` python
59
63
import cvxpy as cp
60
- import torch
64
+ import torch
61
65
from cvxpylayers.torch import CvxpyLayer
62
66
63
67
n, m = 2 , 3
@@ -82,6 +86,36 @@ solution.sum().backward()
82
86
83
87
Note: ` CvxpyLayer ` cannot be traced with ` torch.jit ` .
84
88
89
+ ### JAX
90
+ ``` python
91
+ import cvxpy as cp
92
+ import jax
93
+ from cvxpylayers.jax import CvxpyLayer
94
+
95
+ n, m = 2 , 3
96
+ x = cp.Variable(n)
97
+ A = cp.Parameter((m, n))
98
+ b = cp.Parameter(m)
99
+ constraints = [x >= 0 ]
100
+ objective = cp.Minimize(0.5 * cp.pnorm(A @ x - b, p = 1 ))
101
+ problem = cp.Problem(objective, constraints)
102
+ assert problem.is_dpp()
103
+
104
+ cvxpylayer = CvxpyLayer(problem, parameters = [A, b], variables = [x])
105
+ key = jax.random.PRNGKey(0 )
106
+ key, k1, k2 = jax.random.split(key, 3 )
107
+ A_jax = jax.random.normal(k1, shape = (m, n))
108
+ b_jax = jax.random.normal(k2, shape = (m,))
109
+
110
+ solution, = cvxpylayer(A_jax, b_jax)
111
+
112
+ # compute the gradient of the summed solution with respect to A, b
113
+ dcvxpylayer = jax.grad(lambda A , b : sum (cvxpylayer(A, b)[0 ]), argnums = [0 , 1 ])
114
+ gradA, gradb = dcvxpylayer(A_jax, b_jax)
115
+ ```
116
+
117
+ Note: ` CvxpyLayer ` cannot be traced with the JAX ` jit ` or ` vmap ` operations.
118
+
85
119
### TensorFlow 2
86
120
``` python
87
121
import cvxpy as cp
@@ -118,11 +152,11 @@ Starting with version 0.1.3, cvxpylayers can also differentiate through log-log
118
152
import cvxpy as cp
119
153
import torch
120
154
from cvxpylayers.torch import CvxpyLayer
121
-
155
+
122
156
x = cp.Variable(pos = True )
123
157
y = cp.Variable(pos = True )
124
158
z = cp.Variable(pos = True )
125
-
159
+
126
160
a = cp.Parameter(pos = True , value = 2 .)
127
161
b = cp.Parameter(pos = True , value = 1 .)
128
162
c = cp.Parameter(value = 0.5 )
@@ -168,14 +202,9 @@ To install `pytest`, run:
168
202
pip install pytest
169
203
```
170
204
171
- To run the tests for ` torch ` , in the main directory of this repository, run:
172
- ``` bash
173
- pytest cvxpylayers/torch
174
- ```
175
-
176
- To run the tests for ` tensorflow ` , in the main directory of this repository, run:
205
+ Execute the tests from the main directory of this repository with:
177
206
``` bash
178
- pytest cvxpylayers/tensorflow
207
+ pytest cvxpylayers/{torch,jax, tensorflow}
179
208
```
180
209
181
210
## Projects using cvxpylayers
0 commit comments