@@ -29,7 +29,7 @@ def projection_non_negative(tree: Any) -> Any:
29
29
30
30
.. math::
31
31
32
- \underset{p}{\text{argmin}} ~ || x - p| |_2^2 \quad
32
+ \underset{p}{\text{argmin}} ~ \| x - p\ |_2^2 \quad
33
33
\textrm{subject to} \quad p \ge 0
34
34
35
35
where :math:`x` is the input tree.
@@ -43,16 +43,12 @@ def projection_non_negative(tree: Any) -> Any:
43
43
return jax .tree .map (jax .nn .relu , tree )
44
44
45
45
46
- def _clip_safe (leaf , lower , upper ):
47
- return jnp .clip (jnp .asarray (leaf ), lower , upper )
48
-
49
-
50
46
def projection_box (tree : Any , lower : Any , upper : Any ) -> Any :
51
47
r"""Projection onto box constraints.
52
48
53
49
.. math::
54
50
55
- \underset{p}{\text{argmin}} ~ || x - p| |_2^2 \quad \textrm{subject to} \quad
51
+ \underset{p}{\text{argmin}} ~ \| x - p\ |_2^2 \quad \textrm{subject to} \quad
56
52
\text{lower} \le p \le \text{upper}
57
53
58
54
where :math:`x` is the input tree.
@@ -67,38 +63,38 @@ def projection_box(tree: Any, lower: Any, upper: Any) -> Any:
67
63
Returns:
68
64
projected tree, with the same structure as ``tree``.
69
65
"""
70
- return jax .tree .map (_clip_safe , tree , lower , upper )
66
+ return jax .tree .map (jnp . clip , tree , lower , upper )
71
67
72
68
73
- def projection_hypercube (tree : Any , scale : Any = 1.0 ) -> Any :
69
+ def projection_hypercube (tree : Any , scale : Any = 1 ) -> Any :
74
70
r"""Projection onto the (unit) hypercube.
75
71
76
72
.. math::
77
73
78
- \underset{p}{\text{argmin}} ~ || x - p| |_2^2 \quad \textrm{subject to} \quad
74
+ \underset{p}{\text{argmin}} ~ \| x - p\ |_2^2 \quad \textrm{subject to} \quad
79
75
0 \le p \le \text{scale}
80
76
81
77
where :math:`x` is the input tree.
82
78
83
- By default, we project to the unit hypercube (`scale=1.0 `).
79
+ By default, we project to the unit hypercube (`scale=1`).
84
80
85
81
This is a convenience wrapper around
86
82
:func:`projection_box <optax.projections.projection_box>`.
87
83
88
84
Args:
89
85
tree: tree to project.
90
- scale: scale of the hypercube, a scalar or a tree (default: 1.0 ).
86
+ scale: scale of the hypercube, a scalar or a tree (default: 1).
91
87
92
88
Returns:
93
89
projected tree, with the same structure as ``tree``.
94
90
"""
95
- return projection_box (tree , lower = 0.0 , upper = scale )
91
+ return projection_box (tree , lower = 0 , upper = scale )
96
92
97
93
98
94
@jax .custom_jvp
99
95
def _projection_unit_simplex (values : chex .Array ) -> chex .Array :
100
96
"""Projection onto the unit simplex."""
101
- s = 1.0
97
+ s = 1
102
98
n_features = values .shape [0 ]
103
99
u = jnp .sort (values )[::- 1 ]
104
100
cumsum_u = jnp .cumsum (u )
@@ -121,22 +117,22 @@ def _projection_unit_simplex_jvp(
121
117
return primal_out , tangent_out
122
118
123
119
124
- def projection_simplex (tree : Any , scale : chex .Numeric = 1.0 ) -> Any :
120
+ def projection_simplex (tree : Any , scale : chex .Numeric = 1 ) -> Any :
125
121
r"""Projection onto a simplex.
126
122
127
123
This function solves the following constrained optimization problem,
128
124
where ``x`` is the input tree.
129
125
130
126
.. math::
131
127
132
- \underset{p}{\text{argmin}} ~ || x - p| |_2^2 \quad \textrm{subject to} \quad
128
+ \underset{p}{\text{argmin}} ~ \| x - p\ |_2^2 \quad \textrm{subject to} \quad
133
129
p \ge 0, p^\top 1 = \text{scale}
134
130
135
131
By default, the projection is onto the probability simplex (unit simplex).
136
132
137
133
Args:
138
134
tree: tree to project.
139
- scale: value the projected tree should sum to (default: 1.0 ).
135
+ scale: value the projected tree should sum to (default: 1).
140
136
141
137
Returns:
142
138
projected tree, a tree with the same structure as ``tree``.
@@ -156,25 +152,21 @@ def projection_simplex(tree: Any, scale: chex.Numeric = 1.0) -> Any:
156
152
157
153
.. versionadded:: 0.2.3
158
154
"""
159
- if scale is None :
160
- scale = 1.0
161
-
162
155
values , unravel_fn = flatten_util .ravel_pytree (tree )
163
156
new_values = scale * _projection_unit_simplex (values / scale )
164
-
165
157
return unravel_fn (new_values )
166
158
167
159
168
- def projection_l1_sphere (tree : Any , scale : float = 1.0 ) -> Any :
160
+ def projection_l1_sphere (tree : Any , scale : chex . Numeric = 1 ) -> Any :
169
161
r"""Projection onto the l1 sphere.
170
162
171
163
This function solves the following constrained optimization problem,
172
164
where ``x`` is the input tree.
173
165
174
166
.. math::
175
167
176
- \underset{y}{\text{argmin}} ~ || x - y| |_2^2 \quad \textrm{subject to} \quad
177
- ||y| |_1 = \text{scale}
168
+ \underset{y}{\text{argmin}} ~ \| x - y\ |_2^2 \quad \textrm{subject to} \quad
169
+ \|y\ |_1 = \text{scale}
178
170
179
171
Args:
180
172
tree: tree to project.
@@ -189,16 +181,16 @@ def projection_l1_sphere(tree: Any, scale: float = 1.0) -> Any:
189
181
return otu .tree_mul (tree_sign , tree_abs_proj )
190
182
191
183
192
- def projection_l1_ball (tree : Any , scale : float = 1.0 ) -> Any :
184
+ def projection_l1_ball (tree : Any , scale : chex . Numeric = 1 ) -> Any :
193
185
r"""Projection onto the l1 ball.
194
186
195
187
This function solves the following constrained optimization problem,
196
188
where ``x`` is the input tree.
197
189
198
190
.. math::
199
191
200
- \underset{y}{\text{argmin}} ~ || x - y| |_2^2 \quad \textrm{subject to} \quad
201
- ||y| |_1 \le \text{scale}
192
+ \underset{y}{\text{argmin}} ~ \| x - y\ |_2^2 \quad \textrm{subject to} \quad
193
+ \|y\ |_1 \le \text{scale}
202
194
203
195
Args:
204
196
tree: tree to project.
@@ -229,16 +221,16 @@ def projection_l1_ball(tree: Any, scale: float = 1.0) -> Any:
229
221
)
230
222
231
223
232
- def projection_l2_sphere (tree : Any , scale : float = 1.0 ) -> Any :
224
+ def projection_l2_sphere (tree : Any , scale : chex . Numeric = 1 ) -> Any :
233
225
r"""Projection onto the l2 sphere.
234
226
235
227
This function solves the following constrained optimization problem,
236
228
where ``x`` is the input tree.
237
229
238
230
.. math::
239
231
240
- \underset{y}{\text{argmin}} ~ || x - y| |_2^2 \quad \textrm{subject to} \quad
241
- ||y| |_2 = \text{value}
232
+ \underset{y}{\text{argmin}} ~ \| x - y\ |_2^2 \quad \textrm{subject to} \quad
233
+ \|y\ |_2 = \text{value}
242
234
243
235
Args:
244
236
tree: tree to project.
@@ -253,16 +245,16 @@ def projection_l2_sphere(tree: Any, scale: float = 1.0) -> Any:
253
245
return otu .tree_scale (factor , tree )
254
246
255
247
256
- def projection_l2_ball (tree : Any , scale : float = 1.0 ) -> Any :
248
+ def projection_l2_ball (tree : Any , scale : chex . Numeric = 1 ) -> Any :
257
249
r"""Projection onto the l2 ball.
258
250
259
251
This function solves the following constrained optimization problem,
260
252
where ``x`` is the input tree.
261
253
262
254
.. math::
263
255
264
- \underset{y}{\text{argmin}} ~ || x - y| |_2^2 \quad \textrm{subject to} \quad
265
- ||y| |_2 \le \text{scale}
256
+ \underset{y}{\text{argmin}} ~ \| x - y\ |_2^2 \quad \textrm{subject to} \quad
257
+ \|y\ |_2 \le \text{scale}
266
258
267
259
Args:
268
260
tree: tree to project.
@@ -273,26 +265,20 @@ def projection_l2_ball(tree: Any, scale: float = 1.0) -> Any:
273
265
274
266
.. versionadded:: 0.2.4
275
267
"""
276
- l2_norm = otu .tree_l2_norm (tree )
277
- factor = scale / l2_norm
278
- return jax .lax .cond (
279
- l2_norm <= scale ,
280
- lambda tree : tree ,
281
- lambda tree : otu .tree_scale (factor , tree ),
282
- operand = tree ,
283
- )
268
+ factor = scale / otu .tree_l2_norm (tree ).clip (min = scale )
269
+ return otu .tree_scale (factor , tree )
284
270
285
271
286
- def projection_linf_ball (tree : Any , scale : float = 1.0 ) -> Any :
272
+ def projection_linf_ball (tree : Any , scale : chex . Numeric = 1 ) -> Any :
287
273
r"""Projection onto the l-infinity ball.
288
274
289
275
This function solves the following constrained optimization problem,
290
276
where ``x`` is the input tree.
291
277
292
278
.. math::
293
279
294
- \underset{y}{\text{argmin}} ~ || x - y| |_2^2 \quad \textrm{subject to} \quad
295
- ||y| |_{\infty} \le \text{scale}
280
+ \underset{y}{\text{argmin}} ~ \| x - y\ |_2^2 \quad \textrm{subject to} \quad
281
+ \|y\ |_{\infty} \le \text{scale}
296
282
297
283
Args:
298
284
tree: tree to project.
@@ -301,6 +287,6 @@ def projection_linf_ball(tree: Any, scale: float = 1.0) -> Any:
301
287
Returns:
302
288
projected tree, with the same structure as ``tree``.
303
289
"""
304
- lower_tree = otu .tree_full_like (tree , - scale )
305
- upper_tree = otu .tree_full_like (tree , scale )
306
- return projection_box (tree , lower = lower_tree , upper = upper_tree )
290
+ lower = otu .tree_full_like (tree , - scale )
291
+ upper = otu .tree_full_like (tree , scale )
292
+ return projection_box (tree , lower , upper )
0 commit comments