Skip to content

Commit 9f1d45a

Browse files
committed
Minor cleanup for optax.projections.
1 parent cd06989 commit 9f1d45a

File tree

2 files changed

+33
-47
lines changed

2 files changed

+33
-47
lines changed

docs/api/projections.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The Euclidean projection onto a set :math:`\mathcal{C}` is:
99
.. math::
1010
1111
\text{proj}_{\mathcal{C}}(u) :=
12-
\underset{v}{\text{argmin}} ~ ||u - v||^2_2 \textrm{ subject to } v \in \mathcal{C}.
12+
\underset{v}{\text{argmin}} ~ \|u - v\|^2_2 \textrm{ subject to } v \in \mathcal{C}.
1313
1414
For instance, here is an example how we can project parameters to the non-negative orthant::
1515

optax/projections/_projections.py

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def projection_non_negative(tree: Any) -> Any:
2929
3030
.. math::
3131
32-
\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad
32+
\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad
3333
\textrm{subject to} \quad p \ge 0
3434
3535
where :math:`x` is the input tree.
@@ -43,16 +43,12 @@ def projection_non_negative(tree: Any) -> Any:
4343
return jax.tree.map(jax.nn.relu, tree)
4444

4545

46-
def _clip_safe(leaf, lower, upper):
47-
return jnp.clip(jnp.asarray(leaf), lower, upper)
48-
49-
5046
def projection_box(tree: Any, lower: Any, upper: Any) -> Any:
5147
r"""Projection onto box constraints.
5248
5349
.. math::
5450
55-
\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad
51+
\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad
5652
\text{lower} \le p \le \text{upper}
5753
5854
where :math:`x` is the input tree.
@@ -67,38 +63,38 @@ def projection_box(tree: Any, lower: Any, upper: Any) -> Any:
6763
Returns:
6864
projected tree, with the same structure as ``tree``.
6965
"""
70-
return jax.tree.map(_clip_safe, tree, lower, upper)
66+
return jax.tree.map(jnp.clip, tree, lower, upper)
7167

7268

73-
def projection_hypercube(tree: Any, scale: Any = 1.0) -> Any:
69+
def projection_hypercube(tree: Any, scale: Any = 1) -> Any:
7470
r"""Projection onto the (unit) hypercube.
7571
7672
.. math::
7773
78-
\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad
74+
\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad
7975
0 \le p \le \text{scale}
8076
8177
where :math:`x` is the input tree.
8278
83-
By default, we project to the unit hypercube (`scale=1.0`).
79+
By default, we project to the unit hypercube (`scale=1`).
8480
8581
This is a convenience wrapper around
8682
:func:`projection_box <optax.projections.projection_box>`.
8783
8884
Args:
8985
tree: tree to project.
90-
scale: scale of the hypercube, a scalar or a tree (default: 1.0).
86+
scale: scale of the hypercube, a scalar or a tree (default: 1).
9187
9288
Returns:
9389
projected tree, with the same structure as ``tree``.
9490
"""
95-
return projection_box(tree, lower=0.0, upper=scale)
91+
return projection_box(tree, lower=0, upper=scale)
9692

9793

9894
@jax.custom_jvp
9995
def _projection_unit_simplex(values: chex.Array) -> chex.Array:
10096
"""Projection onto the unit simplex."""
101-
s = 1.0
97+
s = 1
10298
n_features = values.shape[0]
10399
u = jnp.sort(values)[::-1]
104100
cumsum_u = jnp.cumsum(u)
@@ -121,22 +117,22 @@ def _projection_unit_simplex_jvp(
121117
return primal_out, tangent_out
122118

123119

124-
def projection_simplex(tree: Any, scale: chex.Numeric = 1.0) -> Any:
120+
def projection_simplex(tree: Any, scale: chex.Numeric = 1) -> Any:
125121
r"""Projection onto a simplex.
126122
127123
This function solves the following constrained optimization problem,
128124
where ``x`` is the input tree.
129125
130126
.. math::
131127
132-
\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad
128+
\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad
133129
p \ge 0, p^\top 1 = \text{scale}
134130
135131
By default, the projection is onto the probability simplex (unit simplex).
136132
137133
Args:
138134
tree: tree to project.
139-
scale: value the projected tree should sum to (default: 1.0).
135+
scale: value the projected tree should sum to (default: 1).
140136
141137
Returns:
142138
projected tree, a tree with the same structure as ``tree``.
@@ -156,25 +152,21 @@ def projection_simplex(tree: Any, scale: chex.Numeric = 1.0) -> Any:
156152
157153
.. versionadded:: 0.2.3
158154
"""
159-
if scale is None:
160-
scale = 1.0
161-
162155
values, unravel_fn = flatten_util.ravel_pytree(tree)
163156
new_values = scale * _projection_unit_simplex(values / scale)
164-
165157
return unravel_fn(new_values)
166158

167159

168-
def projection_l1_sphere(tree: Any, scale: float = 1.0) -> Any:
160+
def projection_l1_sphere(tree: Any, scale: chex.Numeric = 1) -> Any:
169161
r"""Projection onto the l1 sphere.
170162
171163
This function solves the following constrained optimization problem,
172164
where ``x`` is the input tree.
173165
174166
.. math::
175167
176-
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
177-
||y||_1 = \text{scale}
168+
\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad
169+
\|y\|_1 = \text{scale}
178170
179171
Args:
180172
tree: tree to project.
@@ -189,16 +181,16 @@ def projection_l1_sphere(tree: Any, scale: float = 1.0) -> Any:
189181
return otu.tree_mul(tree_sign, tree_abs_proj)
190182

191183

192-
def projection_l1_ball(tree: Any, scale: float = 1.0) -> Any:
184+
def projection_l1_ball(tree: Any, scale: chex.Numeric = 1) -> Any:
193185
r"""Projection onto the l1 ball.
194186
195187
This function solves the following constrained optimization problem,
196188
where ``x`` is the input tree.
197189
198190
.. math::
199191
200-
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
201-
||y||_1 \le \text{scale}
192+
\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad
193+
\|y\|_1 \le \text{scale}
202194
203195
Args:
204196
tree: tree to project.
@@ -229,16 +221,16 @@ def projection_l1_ball(tree: Any, scale: float = 1.0) -> Any:
229221
)
230222

231223

232-
def projection_l2_sphere(tree: Any, scale: float = 1.0) -> Any:
224+
def projection_l2_sphere(tree: Any, scale: chex.Numeric = 1) -> Any:
233225
r"""Projection onto the l2 sphere.
234226
235227
This function solves the following constrained optimization problem,
236228
where ``x`` is the input tree.
237229
238230
.. math::
239231
240-
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
241-
||y||_2 = \text{value}
232+
\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad
233+
\|y\|_2 = \text{value}
242234
243235
Args:
244236
tree: tree to project.
@@ -253,16 +245,16 @@ def projection_l2_sphere(tree: Any, scale: float = 1.0) -> Any:
253245
return otu.tree_scale(factor, tree)
254246

255247

256-
def projection_l2_ball(tree: Any, scale: float = 1.0) -> Any:
248+
def projection_l2_ball(tree: Any, scale: chex.Numeric = 1) -> Any:
257249
r"""Projection onto the l2 ball.
258250
259251
This function solves the following constrained optimization problem,
260252
where ``x`` is the input tree.
261253
262254
.. math::
263255
264-
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
265-
||y||_2 \le \text{scale}
256+
\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad
257+
\|y\|_2 \le \text{scale}
266258
267259
Args:
268260
tree: tree to project.
@@ -273,26 +265,20 @@ def projection_l2_ball(tree: Any, scale: float = 1.0) -> Any:
273265
274266
.. versionadded:: 0.2.4
275267
"""
276-
l2_norm = otu.tree_l2_norm(tree)
277-
factor = scale / l2_norm
278-
return jax.lax.cond(
279-
l2_norm <= scale,
280-
lambda tree: tree,
281-
lambda tree: otu.tree_scale(factor, tree),
282-
operand=tree,
283-
)
268+
factor = scale / otu.tree_l2_norm(tree).clip(min=scale)
269+
return otu.tree_scale(factor, tree)
284270

285271

286-
def projection_linf_ball(tree: Any, scale: float = 1.0) -> Any:
272+
def projection_linf_ball(tree: Any, scale: chex.Numeric = 1) -> Any:
287273
r"""Projection onto the l-infinity ball.
288274
289275
This function solves the following constrained optimization problem,
290276
where ``x`` is the input tree.
291277
292278
.. math::
293279
294-
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
295-
||y||_{\infty} \le \text{scale}
280+
\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad
281+
\|y\|_{\infty} \le \text{scale}
296282
297283
Args:
298284
tree: tree to project.
@@ -301,6 +287,6 @@ def projection_linf_ball(tree: Any, scale: float = 1.0) -> Any:
301287
Returns:
302288
projected tree, with the same structure as ``tree``.
303289
"""
304-
lower_tree = otu.tree_full_like(tree, -scale)
305-
upper_tree = otu.tree_full_like(tree, scale)
306-
return projection_box(tree, lower=lower_tree, upper=upper_tree)
290+
lower = otu.tree_full_like(tree, -scale)
291+
upper = otu.tree_full_like(tree, scale)
292+
return projection_box(tree, lower, upper)

0 commit comments

Comments
 (0)