Skip to content

Commit af3c8f0

Browse files
committed
Initial commit of JAX implementation
1 parent 483e922 commit af3c8f0

File tree

8 files changed

+1503
-26
lines changed

8 files changed

+1503
-26
lines changed

.appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ build_script:
1919
- pip install .
2020

2121
test_script:
22-
- pytest
22+
- pytest cvxpylayers/torch cvxpylayers/tensorflow

.travis.yml

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
matrix:
22
include:
3-
- os: linux
4-
dist: xenial
5-
language: python
6-
python: "3.5"
73
- os: linux
84
dist: xenial
95
language: python
@@ -12,10 +8,6 @@ matrix:
128
dist: xenial
139
language: python
1410
python: "3.7"
15-
- os: linux
16-
dist: bionic
17-
language: python
18-
python: "3.5"
1911
- os: linux
2012
dist: bionic
2113
language: python
@@ -31,6 +23,7 @@ before_install:
3123

3224
install:
3325
- pip install --upgrade pip
26+
- pip install jax==0.2.12 jaxlib==0.1.64
3427
- pip install tensorflow pytest flake8 jupyter matplotlib sklearn tqdm
3528
- pip install torch==1.3.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
3629
- pip install .

README.md

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# cvxpylayers
66

77
cvxpylayers is a Python library for constructing differentiable convex
8-
optimization layers in PyTorch and TensorFlow using CVXPY.
8+
optimization layers in PyTorch, JAX, and TensorFlow using CVXPY.
99
A convex optimization layer solves a parametrized convex optimization problem
1010
in the forward pass to produce a solution.
1111
It computes the derivative of the solution with respect to
@@ -36,28 +36,32 @@ cvxpylayers.
3636
pip install cvxpylayers
3737
```
3838

39-
Our package includes convex optimization layers for PyTorch and TensorFlow 2.0;
39+
Our package includes convex optimization layers for
40+
PyTorch, JAX, and TensorFlow 2.0;
4041
the layers are functionally equivalent. You will need to install
41-
[PyTorch](https://pytorch.org) or [TensorFlow](https://www.tensorflow.org)
42+
[PyTorch](https://pytorch.org),
43+
[JAX](https://github.com/google/jax), or
44+
[TensorFlow](https://www.tensorflow.org)
4245
separately, which can be done by following the instructions on their websites.
4346

4447
cvxpylayers has the following dependencies:
4548
* Python 3
4649
* [NumPy](https://pypi.org/project/numpy/)
4750
* [CVXPY](https://github.com/cvxgrp/cvxpy) >= 1.1.a4
48-
* [TensorFlow](https://tensorflow.org) >= 2.0 or [PyTorch](https://pytorch.org) >= 1.0
51+
* [PyTorch](https://pytorch.org) >= 1.0, [JAX](https://github.com/google/jax) >= 0.2.12, or [TensorFlow](https://tensorflow.org) >= 2.0
4952
* [diffcp](https://github.com/cvxgrp/diffcp) >= 1.0.13
5053

5154
## Usage
52-
Below are usage examples of our PyTorch and TensorFlow layers. Note that
53-
the parametrized convex optimization problems must be constructed in CVXPY,
54-
using [DPP](https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming).
55+
Below are usage examples of our PyTorch, JAX, and TensorFlow layers.
56+
Note that the parametrized convex optimization problems must be constructed
57+
in CVXPY, using
58+
[DPP](https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming).
5559

5660
### PyTorch
5761

5862
```python
5963
import cvxpy as cp
60-
import torch
64+
import torch
6165
from cvxpylayers.torch import CvxpyLayer
6266

6367
n, m = 2, 3
@@ -82,6 +86,36 @@ solution.sum().backward()
8286

8387
Note: `CvxpyLayer` cannot be traced with `torch.jit`.
8488

89+
### JAX
90+
```python
91+
import cvxpy as cp
92+
import jax
93+
from cvxpylayers.jax import CvxpyLayer
94+
95+
n, m = 2, 3
96+
x = cp.Variable(n)
97+
A = cp.Parameter((m, n))
98+
b = cp.Parameter(m)
99+
constraints = [x >= 0]
100+
objective = cp.Minimize(0.5 * cp.pnorm(A @ x - b, p=1))
101+
problem = cp.Problem(objective, constraints)
102+
assert problem.is_dpp()
103+
104+
cvxpylayer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
105+
key = jax.random.PRNGKey(0)
106+
key, k1, k2 = jax.random.split(key, 3)
107+
A_jax = jax.random.normal(k1, shape=(m, n))
108+
b_jax = jax.random.normal(k2, shape=(m,))
109+
110+
solution, = cvxpylayer(A_jax, b_jax)
111+
112+
# compute the gradient of the summed solution with respect to A, b
113+
dcvxpylayer = jax.grad(lambda A, b: sum(cvxpylayer(A, b)[0]), argnums=[0, 1])
114+
gradA, gradb = dcvxpylayer(A_jax, b_jax)
115+
```
116+
117+
Note: `CvxpyLayer` cannot be traced with the JAX `jit` or `vmap` operations.
118+
85119
### TensorFlow 2
86120
```python
87121
import cvxpy as cp
@@ -118,11 +152,11 @@ Starting with version 0.1.3, cvxpylayers can also differentiate through log-log
118152
import cvxpy as cp
119153
import torch
120154
from cvxpylayers.torch import CvxpyLayer
121-
155+
122156
x = cp.Variable(pos=True)
123157
y = cp.Variable(pos=True)
124158
z = cp.Variable(pos=True)
125-
159+
126160
a = cp.Parameter(pos=True, value=2.)
127161
b = cp.Parameter(pos=True, value=1.)
128162
c = cp.Parameter(value=0.5)
@@ -168,14 +202,9 @@ To install `pytest`, run:
168202
pip install pytest
169203
```
170204

171-
To run the tests for `torch`, in the main directory of this repository, run:
172-
```bash
173-
pytest cvxpylayers/torch
174-
```
175-
176-
To run the tests for `tensorflow`, in the main directory of this repository, run:
205+
Execute the tests from the main directory of this repository with:
177206
```bash
178-
pytest cvxpylayers/tensorflow
207+
pytest cvxpylayers/{torch,jax,tensorflow}
179208
```
180209

181210
## Projects using cvxpylayers

cvxpylayers/jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from cvxpylayers.jax.cvxpylayer import CvxpyLayer # noqa: F401

0 commit comments

Comments
 (0)