diff --git a/benchmarks/sparse_root_prox_gradient_benchmark.py b/benchmarks/sparse_root_prox_gradient_benchmark.py new file mode 100644 index 00000000..f9eb8822 --- /dev/null +++ b/benchmarks/sparse_root_prox_gradient_benchmark.py @@ -0,0 +1,101 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Implicit differentiation of lasso. +================================== +""" + +import time +from absl import app +from absl import flags + +import jax +import jax.numpy as jnp + +from jaxopt import objective +from jaxopt import prox +from jaxopt import ProximalGradient +from jaxopt import support +from jaxopt._src import linear_solve + +from sklearn import datasets +from sklearn import model_selection + + +def outer_objective(theta, init_inner, data, support, implicit_diff_solve): + """Validation loss.""" + X_tr, X_val, y_tr, y_val = data + # We use the bijective mapping lam = jnp.exp(theta) to ensure positivity. + lam = jnp.exp(theta) + + solver = ProximalGradient( + fun=objective.least_squares, + support=support, + prox=prox.prox_lasso, + implicit_diff=True, + implicit_diff_solve=implicit_diff_solve, + maxiter=5000, + tol=1e-10) + + # The format is run(init_params, hyperparams_prox, *args, **kwargs) + # where *args and **kwargs are passed to `fun`. + w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params + + y_pred = jnp.dot(X_val, w_fit) + loss_value = jnp.mean((y_pred - y_val) ** 2) + + # We return w_fit as auxiliary data. + # Auxiliary data is stored in the optimizer state (see below). 
+ return loss_value, w_fit + + +X, y = datasets.make_regression( + n_samples=100, n_features=6000, n_informative=10, random_state=0) + +n_samples = X.shape[0] +data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0) + +exp_lam_max = jnp.max(jnp.abs(X.T @ y)) +lam = jnp.log(exp_lam_max / (100 * n_samples)) + +init_inner = jnp.zeros(X.shape[1]) + + +t0 = time.time() +grad = jax.grad(outer_objective, has_aux=True) +gradient, coef_ = grad( + lam, init_inner, data, support.support_all, linear_solve.solve_cg) +gradient.block_until_ready() +delta_t = time.time() - t0 +desc='Gradients w/o support, CG' +print(f'{desc} ({delta_t:.3f} sec.): {gradient} ') + +t0 = time.time() +grad = jax.grad(outer_objective, has_aux=True) +gradient, coef_ = grad( + lam, init_inner, data, support.support_all, linear_solve.solve_normal_cg) +gradient.block_until_ready() +delta_t = time.time() - t0 +desc='Gradients w/o support, normal CG' +print(f'{desc} ({delta_t:.3f} sec.): {gradient} ') + +t0 = time.time() +grad = jax.grad(outer_objective, has_aux=True) +gradient, coef_ = grad( + lam, init_inner, data, support.support_nonzero, linear_solve.solve_cg) +gradient.block_until_ready() +delta_t = time.time() - t0 +desc='Gradients w/ masked support' +print(f'{desc} ({delta_t:.3f} sec.): {gradient}') diff --git a/benchmarks/sparse_root_vjp_benchmark.py b/benchmarks/sparse_root_vjp_benchmark.py new file mode 100644 index 00000000..53c312f3 --- /dev/null +++ b/benchmarks/sparse_root_vjp_benchmark.py @@ -0,0 +1,115 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark VJP of Lasso, with and without specifying the support.""" + +import time +import numpy as onp +import jax +import jax.numpy as jnp + +from typing import Sequence +from absl import app +from sklearn import datasets + +from jaxopt import implicit_diff as idf +from jaxopt._src import linear_solve +from jaxopt import objective +from jaxopt import prox +from jaxopt import support +from jaxopt import tree_util +from jaxopt._src import test_util + + +def lasso_optimality_fun(params, lam, X, y): + step = params - jax.grad(objective.least_squares)(params, (X, y)) + return prox.prox_lasso(step, l1reg=lam, scaling=1.) 
- params + + +def get_vjp(lam, X, y, sol, support_fun, maxiter, solve_fn=linear_solve.solve_cg): + def solve(matvec, b): + return solve_fn(matvec, b, tol=1e-6, maxiter=maxiter) + + vjp = lambda g: idf.root_vjp(optimality_fun=lasso_optimality_fun, + support_fun=support_fun, + sol=sol, + args=(lam, X, y), + cotangent=g, + solve=solve)[0] + return vjp + + +def benchmark_vjp(vjp, X, supp=None, desc=''): + t0 = time.time() + result = jax.vmap(vjp)(jnp.eye(X.shape[1])) + result.block_until_ready() + delta_t = time.time() - t0 + + size_support = onp.sum(result != 0) + + if supp is not None: + result = result[supp] + print(f'{desc} ({delta_t:.3f} sec.): {result} ' + f'(size of the support: {size_support:d})') + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + X, y = datasets.make_regression(n_samples=100, n_features=1000, + n_informative=5, random_state=0) + print(f'Number of samples: {X.shape[0]:d}') + print(f'Number of features: {X.shape[1]:d}') + + lam = 1e-3 * onp.max(onp.abs(X.T @ y)) / X.shape[0] + print(f'Value of lambda: {lam:.5f}') + + sol = test_util.lasso_skl(X, y, lam, tol=1e-6, fit_intercept=False) + supp = (sol != 0) + print(f'Size of the support of the solution: {supp.sum():d}') + + optimality = tree_util.tree_l2_norm(lasso_optimality_fun(sol, lam, X, y)) + print(f'Optimality of the solution (L2-norm): {optimality:.5f}') + + jac_num = test_util.lasso_skl_jac(X, y, lam, tol=1e-8, eps=1e-8) + print(f'Numerical Jacobian: {jac_num[supp]}') + + # Compute the Jacobian w.r.t. lambda, without using the information about the + # support of the solution. This is the default behavior in JAXopt, and + # requires solving a linear system with a 1000x1000 dense matrix. Ignoring + # the support of the solution with CG leads to an inaccurate Jacobian. + vjp = get_vjp(lam, X, y, sol, support.support_all, maxiter=1000) + benchmark_vjp(vjp, X, supp=supp, desc='Jacobian w/o support, CG') + + vjp = get_vjp(lam, X, y, sol, support.support_all, maxiter=1000, + solve_fn=linear_solve.solve_normal_cg) + benchmark_vjp(vjp, X, supp=supp, desc='Jacobian w/o support, normal CG') + + # Compute the Jacobian w.r.t. lambda, restricting the data X and the solution + # to the support of the solution. This requires solving a linear system with + # a 4x4 dense matrix. Restricting the support this way is not jit-friendly, + # and will require new compilation if the size of the support changes. + vjp = get_vjp(lam, X[:, supp], y, sol[supp], support.support_all, maxiter=1000) + benchmark_vjp(vjp, X[:, supp], desc='Jacobian w/ restricted support') + + # Compute the Jacobian w.r.t. lambda, by masking the linear system + # with the support of the solution. This requires solving a linear system with + # a 1000x1000 sparse matrix. Masking with the support is jit-friendly. 
+ vjp = get_vjp(lam, X, y, sol, support.support_nonzero, maxiter=1000) + benchmark_vjp(vjp, X, supp=supp, desc='Jacobian w/ masked support') + + +if __name__ == '__main__': + app.run(main) diff --git a/docs/implicit_diff.rst b/docs/implicit_diff.rst index 25c98196..482767c6 100644 --- a/docs/implicit_diff.rst +++ b/docs/implicit_diff.rst @@ -58,6 +58,26 @@ We can also compose ``ridge_reg_solution`` with other functions:: * :ref:`sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py` * :ref:`sphx_glr_auto_examples_implicit_diff_sparse_coding.py` +Non-smooth functions +-------------------- + +When the function :math:`f(x, \theta)` is non-smooth (e.g., in Lasso), implicit +differentiation can still be applied to compute the Jacobian +:math:`\partial x^\star(\theta)`. However, the linear system to solve to obtain +the Jacobian (or, alternatively, the vector-Jacobian product) must be restricted +to the (generalized) *support* :math:`S` of the solution :math:`x^\star`. To +give a hint to the linear solver called in ``root_vjp``, you may specify a +support function ``support`` that will restrict the linear system to the support +of the solution. + +The ``support`` function must return a pytree with the same structure and dtype +as the solution, where ``support(x)`` is equal to 1 for the coordinates :math:`j` +in the support (:math:`x_{j} \in S`), and 0 otherwise. The support function +depends on the non-smooth function being optimized; see :ref:`support functions +` for examples of support functions. Note that the +support function merely masks out coordinates outside of the support, making it +fully compatible with ``jit`` compilation. + Custom solvers -------------- @@ -92,3 +112,8 @@ of roots of functions. `_, Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert. ArXiv preprint. + + * `Implicit Differentiation for Fast Hyperparameter Selection in Non-Smooth Convex Learning + `_, + Quentin Bertrand, Quentin Klopfenstein, Mathurin Massias, Mathieu Blondel, Samuel Vaiter, Alexandre Gramfort, Joseph Salmon. + Journal of Machine Learning Research (JMLR). diff --git a/docs/non_smooth.rst b/docs/non_smooth.rst index 79c5a2a1..9bf48ac4 100644 --- a/docs/non_smooth.rst +++ b/docs/non_smooth.rst @@ -82,6 +82,23 @@ and autodiff of unrolled iterations if ``implicit_diff=False``. See the * :ref:`sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py` * :ref:`sphx_glr_auto_examples_implicit_diff_sparse_coding.py` +When using implicit differentiation, you can optionally specify a support +function ``support`` to give a hint to the linear solver called in ``root_vjp`` +and only solve the linear system restricted to the support of the solution:: + + from jaxopt.support import support_nonzero + + def solution(l1reg): + pg = ProximalGradient(fun=least_squares, prox=prox_lasso, + support=support_nonzero, implicit_diff=True) + return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params + + # Both the solution & the Jacobian have the same support + print(solution(l1reg)) + print(jax.jacobian(solution)(l1reg)) + +See the :ref:`implicit differentiation ` section for more details. + .. _block_coordinate_descent: Block coordinate descent @@ -134,3 +151,24 @@ The following operators are available. jaxopt.prox.prox_group_lasso jaxopt.prox.prox_ridge jaxopt.prox.prox_non_negative_ridge + +.. 
_support_functions: + +Support functions +----------------- + +A support function :math:`S(x)` returns 1 for all the +coordinates of :math:`x` in the support, and 0 otherwise: + +.. math:: + + S(x)_{j} := \begin{cases} 1 & \textrm{if $x_{j} \in S$} \\ 0 & \textrm{otherwise} \end{cases} + +The following support functions are available. + +.. autosummary:: + :toctree: _autosummary + + jaxopt.support.support_all + jaxopt.support.support_nonzero + jaxopt.support.support_group_nonzero diff --git a/jaxopt/_src/base.py b/jaxopt/_src/base.py index a767f9eb..0d27c212 100644 --- a/jaxopt/_src/base.py +++ b/jaxopt/_src/base.py @@ -204,10 +204,12 @@ def run(self, run = self._run if getattr(self, "implicit_diff", True): + support_fun = getattr(self, "support", None) reference_signature = getattr(self, "reference_signature", None) decorator = idf.custom_root( self.optimality_fun, has_aux=True, + support_fun=support_fun, solve=self.implicit_diff_solve, reference_signature=reference_signature) run = decorator(run) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 1eae2e62..0ff6fd26 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -15,12 +15,14 @@ """Implicit differentiation of roots and fixed points.""" import inspect +import functools from typing import Any from typing import Callable from typing import Optional from typing import Tuple import jax +import jax.numpy as jnp from jaxopt._src import base from jaxopt._src import linear_solve @@ -28,9 +30,11 @@ from jaxopt._src.tree_util import tree_mul from jaxopt._src.tree_util import tree_scalar_mul from jaxopt._src.tree_util import tree_sub +from jaxopt._src.tree_util import tree_map def root_vjp(optimality_fun: Callable, + support_fun: Callable, sol: Any, args: Tuple, cotangent: Any, @@ -41,6 +45,7 @@ def root_vjp(optimality_fun: Callable, Args: optimality_fun: the optimality function to use. + support_fun: the support function. sol: solution / root (pytree). args: tuple containing the arguments with respect to which we wish to differentiate ``sol`` against. @@ -57,15 +62,18 @@ def fun_sol(sol): # We close over the arguments. return optimality_fun(sol, *args) + support = support_fun(sol) _, vjp_fun_sol = jax.vjp(fun_sol, sol) # Compute the multiplication A^T u = (u^T A)^T. - matvec = lambda u: vjp_fun_sol(u)[0] + def matvec(u): + Au = vjp_fun_sol(u)[0] + return tree_mul(Au, support) # The solution of A^T u = v, where # A = jacobian(optimality_fun, argnums=0) # v = -cotangent. - v = tree_scalar_mul(-1, cotangent) + v = tree_scalar_mul(-1, tree_mul(cotangent, support)) u = solve(matvec, v) def fun_args(*args): @@ -168,8 +176,8 @@ def map_back(out_args): return out_args, out_kwargs, map_back -def _custom_root(solver_fun, optimality_fun, solve, has_aux, - reference_signature=None): +def _custom_root(solver_fun, optimality_fun, support_fun, solve, + has_aux, reference_signature=None): # When caling through `jax.custom_vjp`, jax attempts to resolve all # arguments passed by keyword to positions (this is in order to # match against a `nondiff_argnums` parameter that we do not use @@ -233,8 +241,9 @@ def solver_fun_bwd(tup, cotangent): "both of which are currently unsupported.") # Compute VJPs w.r.t. args. 
- vjps = root_vjp(optimality_fun=optimality_fun, sol=sol, - args=ba_args[1:], cotangent=cotangent, solve=solve) + vjps = root_vjp(optimality_fun=optimality_fun, support_fun=support_fun, + sol=sol, args=ba_args[1:], cotangent=cotangent, + solve=solve) # Prepend None as the vjp for init_params. vjps = (None,) + vjps @@ -255,6 +264,7 @@ def wrapped_solver_fun(*args, **kwargs): def custom_root(optimality_fun: Callable, has_aux: bool = False, + support_fun: Optional[Callable] = None, solve: Callable = linear_solve.solve_normal_cg, reference_signature: Optional[Callable] = None): """Decorator for adding implicit differentiation to a root solver. @@ -264,6 +274,10 @@ def custom_root(optimality_fun: Callable, The invariant is ``optimality_fun(sol, *args) == 0`` at the solution / root ``sol``. has_aux: whether the decorated solver function returns auxiliary data. + support_fun: optional support function ``support_fun(params)``, returning + the support of a pytree ``params``. This function returns a pytree with + the same structure and dtypes as ``params``, equal to 1 for the + coordinates of ``params`` in the support, and 0 otherwise. solve: a linear solver of the form ``solve(matvec, b)``. reference_signature: optional function signature (i.e. arguments and keyword arguments), with which the @@ -282,9 +296,12 @@ def custom_root(optimality_fun: Callable, if solve is None: solve = linear_solve.solve_normal_cg + if support_fun is None: + support_fun = functools.partial(tree_map, jnp.ones_like) + def wrapper(solver_fun): - return _custom_root(solver_fun, optimality_fun, solve, has_aux, - reference_signature) + return _custom_root(solver_fun, optimality_fun, support_fun, solve, + has_aux, reference_signature) return wrapper diff --git a/jaxopt/_src/proximal_gradient.py b/jaxopt/_src/proximal_gradient.py index 0fa39b72..81920d48 100644 --- a/jaxopt/_src/proximal_gradient.py +++ b/jaxopt/_src/proximal_gradient.py @@ -31,6 +31,7 @@ from jaxopt._src import base from jaxopt._src import loop from jaxopt._src.prox import prox_none +from jaxopt._src.support import support_all from jaxopt._src.tree_util import tree_add_scalar_mul from jaxopt._src.tree_util import tree_l2_norm from jaxopt._src.tree_util import tree_sub @@ -104,6 +105,11 @@ class ProximalGradient(base.IterativeSolver): prox: proximity operator associated with the function ``non_smooth``. It should be of the form ``prox(params, hyperparams_prox, scale=1.0)``. See ``jaxopt.prox`` for examples. + support: a function of the form ``support(params)``, returning + the support of a pytree ``params``. This function returns a pytree with + the same structure and dtypes as ``params``, equal to 1 for the + coordinates of ``params`` in the support, and 0 otherwise. + See ``jaxopt.support`` for examples. stepsize: a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the **positive** stepsize to use at each iteration. maxiter: maximum number of proximal gradient descent iterations. @@ -130,6 +136,7 @@ class ProximalGradient(base.IterativeSolver): """ fun: Callable prox: Callable = prox_none + support: Callable = support_all stepsize: Union[float, Callable] = 0.0 maxiter: int = 500 maxls: int = 15 diff --git a/jaxopt/_src/support.py b/jaxopt/_src/support.py new file mode 100644 index 00000000..db03028a --- /dev/null +++ b/jaxopt/_src/support.py @@ -0,0 +1,82 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Support functions.""" + +from typing import Any + +import jax.numpy as jnp + +from jaxopt._src import tree_util + + +def support_all(x: Any): + r"""Support function where all the coordinates are in the support. + + If :math:`S` is the support set, then :math:`\forall i`, :math:`x_{i} \in S`. + + Args: + x: input pytree. + Returns: + output pytree, with the same structure and dtypes as ``x``, where all the + coordinates are equal to 1. + """ + return tree_util.tree_map(jnp.ones_like, x) + + +def support_nonzero(x: Any): + r"""Support function where the support corresponds to non-zero coordinates. + + If :math:`S` is the support set, then :math:`\forall i`, + + .. math:: + + x_{i} \in S \Leftrightarrow x_{i} \neq 0 + + This support function is typically used for sparse objects with unstructured + sparsity patterns, e.g., with the operators ``jaxopt.prox.prox_lasso`` or + ``jaxopt.prox.prox_elastic_net``. + + Args: + x: input pytree. + Returns: + output pytree, with the same structure and dtypes as ``x``, equal to 1 + if ``x[i] != 0``, and 0 otherwise. + """ + fun = lambda u: (u != 0).astype(u.dtype) + return tree_util.tree_map(fun, x) + + +def support_group_nonzero(x: Any): + r"""Support function where the support corresponds to groups of non-zero + coordinates. + + If :math:`S` is the support set, then :math:`\forall i`, + + .. math:: + x_{i} \in S \Leftrightarrow x \not\equiv 0 + + Blocks can be grouped using ``jax.vmap``. This support function is typically + used for sparse objects with structured sparsity patterns, e.g., with the + operator ``jaxopt.prox.prox_group_lasso``. + + Args: + x: input pytree. + Returns: + output pytree, with the same structure and dtypes as ``x``, where all the + coordinates are equal to 1 if there exists an ``x[i] != 0``, and all equal + to 0 otherwise. + """ + fun = lambda u: jnp.any(u != 0) * jnp.ones_like(u) + return tree_util.tree_map(fun, x) diff --git a/jaxopt/support.py b/jaxopt/support.py new file mode 100644 index 00000000..ea43ebf6 --- /dev/null +++ b/jaxopt/support.py @@ -0,0 +1,17 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from jaxopt._src.support import support_all +from jaxopt._src.support import support_nonzero +from jaxopt._src.support import support_group_nonzero diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index ac928b0c..519e04b1 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -21,6 +21,9 @@ import jax.numpy as jnp from jaxopt import implicit_diff as idf +from jaxopt import objective +from jaxopt import prox +from jaxopt import support from jaxopt._src import test_util from sklearn import datasets @@ -44,6 +47,19 @@ def ridge_solver_with_kwargs(init_params, **kw): return ridge_solver(init_params, kw['lam'], kw['X'], kw['y']) +def lasso_optimality_fun(params, lam, sol, X, y): + del sol # not used + step = params - jax.grad(objective.least_squares)(params, (X, y)) + return prox.prox_lasso(step, l1reg=lam, scaling=1.) - params + + +def lasso_oracle_solver(init_params, lam, sol, X, y): + # Jax-compatible function that returns the solution of Lasso, + # obtained with an oracle (e.g., jaxopt.test_util.lasso_skl) + del init_params, lam, X, y # not used + return sol + + class ImplicitDiffTest(test_util.JaxoptTestCase): def test_root_vjp(self): @@ -52,6 +68,7 @@ def test_root_vjp(self): lam = 5.0 sol = ridge_solver(None, lam, X, y) vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun, + support_fun=jnp.ones_like, sol=sol, args=(lam, X, y), cotangent=g)[0] # vjp w.r.t. lam @@ -60,6 +77,63 @@ def test_root_vjp(self): J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_root_vjp_support(self): + X, y = datasets.make_regression(n_samples=10, n_features=10, + n_informative=3, random_state=0) + lam = 5.0 + sol = test_util.lasso_skl(X, y, lam, tol=1e-5, fit_intercept=False) + vjp = lambda g: idf.root_vjp(optimality_fun=lasso_optimality_fun, + support_fun=support.support_nonzero, + sol=sol, + args=(lam, sol, X, y), + cotangent=g)[0] # vjp w.r.t. lam + I = jnp.eye(len(sol)) + J = jax.vmap(vjp)(I) + self.assertArraysEqual( + support.support_nonzero(J), + support.support_nonzero(sol) + ) + J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) + self.assertArraysEqual( + support.support_nonzero(J_num), + support.support_nonzero(sol) + ) + self.assertArraysAllClose(J, J_num, atol=5e-2) + + def test_root_vjp_support_jit(self): + X, y = datasets.make_regression(n_samples=10, n_features=10, + n_informative=3, random_state=0) + lam = 5.0 + sol = test_util.lasso_skl(X, y, lam, tol=1e-5, fit_intercept=False) + vjp = lambda g, sol, X: idf.root_vjp(optimality_fun=lasso_optimality_fun, + support_fun=support.support_nonzero, + sol=sol, + args=(lam, sol, X, y), + cotangent=g)[0] # vjp w.r.t. lam + jacobian = jax.jit(jax.vmap(vjp, in_axes=(0, None, None))) + I = jnp.eye(len(sol)) + J = jacobian(I, sol, X) + J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) + self.assertArraysAllClose(J, J_num, atol=5e-2) + + # Take an arbitrary coordinate in the support and mask it in X to get + # a solution with a smaller support. + support_idx = (sol != 0).argmax() + X[:, support_idx] = 0. + sol_sub = test_util.lasso_skl(X, y, lam, tol=1e-5, fit_intercept=False) + self.assertLess( + support.support_nonzero(sol_sub).sum(), + support.support_nonzero(sol).sum() + ) + + # Compute the Jacobian matrix with the smaller-support solution, and verify + # that the function `jacobian` has only been compiled once (no + # recompilation necessary, despite a smaller support). 
+ J_sub = jacobian(I, sol_sub, X) + J_sub_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) + self.assertArraysAllClose(J_sub, J_sub_num, atol=5e-2) + self.assertEqual(jacobian._cache_size(), 1) + def test_root_jvp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) optimality_fun = jax.grad(ridge_objective) @@ -100,6 +174,28 @@ def ridge_solver_with_aux(init_params, lam, X, y): J, _ = jax.jacrev(ridge_solver_decorated, argnums=1)(None, lam, X=X, y=y) self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_custom_root_support(self): + X, y = datasets.make_regression(n_samples=10, n_features=10, + n_informative=3, random_state=0) + lam = 5.0 + lasso_solver_decorated = idf.custom_root( + optimality_fun=lasso_optimality_fun, + support_fun=support.support_nonzero)(lasso_oracle_solver) + sol = test_util.lasso_skl(X, y, lam, tol=1e-5, fit_intercept=False) + sol_decorated = lasso_solver_decorated(None, lam, sol, X, y) + self.assertArraysEqual(sol, sol_decorated) # The solver uses an oracle + J = jax.jacrev(lasso_solver_decorated, argnums=1)(None, lam, sol, X, y) + self.assertArraysEqual( + support.support_nonzero(J), + support.support_nonzero(sol) + ) + J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) + self.assertArraysEqual( + support.support_nonzero(J_num), + support.support_nonzero(sol) + ) + self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_custom_fixed_point(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) grad_fun = jax.grad(ridge_objective) diff --git a/tests/import_test.py b/tests/import_test.py index 0b6882c1..0c291583 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -32,6 +32,10 @@ def test_projection(self): jaxopt.projection.projection_simplex from jaxopt.projection import projection_simplex + def test_support(self): + jaxopt.support.support_all + from jaxopt.support import support_all + def test_tree_util(self): from jaxopt.tree_util import tree_vdot diff --git a/tests/proximal_gradient_test.py b/tests/proximal_gradient_test.py index cf12386d..24d09a91 100644 --- a/tests/proximal_gradient_test.py +++ b/tests/proximal_gradient_test.py @@ -25,6 +25,7 @@ from jaxopt import objective from jaxopt import projection from jaxopt import prox +from jaxopt import support from jaxopt import ProximalGradient from jaxopt._src import test_util from jaxopt import tree_util as tu @@ -118,6 +119,34 @@ def wrapper(hyperparams_prox): jac_prox = jax.jacrev(wrapper)(lam) self.assertArraysAllClose(jac_num, jac_prox, atol=1e-3) + def test_lasso_implicit_diff_sparse(self): + X, y = datasets.make_regression(n_samples=10, n_features=10, + n_informative=3, random_state=0) + lam = 10.0 + data = (X, y) + + fun = objective.least_squares + jac_num = test_util.lasso_skl_jac(X, y, lam) + w_skl = test_util.lasso_skl(X, y, lam) + self.assertArraysEqual( + support.support_nonzero(jac_num), + support.support_nonzero(w_skl) + ) + + pg = ProximalGradient(fun=fun, prox=prox.prox_lasso, + support=support.support_nonzero, tol=1e-3, + maxiter=200, acceleration=True, implicit_diff=True) + + def wrapper(hyperparams_prox): + return pg.run(w_skl, hyperparams_prox, data).params + + jac_prox = jax.jacrev(wrapper)(lam) + self.assertArraysEqual( + support.support_nonzero(jac_prox), + support.support_nonzero(w_skl) + ) + self.assertArraysAllClose(jac_num, jac_prox, atol=1e-3) + def test_stepsize_schedule(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) fun = objective.least_squares # fun(params, data) diff 
--git a/tests/support_test.py b/tests/support_test.py new file mode 100644 index 00000000..e35ecaf2 --- /dev/null +++ b/tests/support_test.py @@ -0,0 +1,52 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +from jaxopt import support +from jaxopt._src import test_util + +import numpy as onp + + +class SupportTest(test_util.JaxoptTestCase): + + def test_support_all(self): + rng = onp.random.RandomState(0) + x = rng.rand(20) * 2 - 1 + supp = onp.ones_like(x) + self.assertArraysEqual(support.support_all(x), supp) + + def test_support_nonzero(self): + rng = onp.random.RandomState(0) + x = rng.rand(20) * 2 - 1 + x = onp.where(x > 0, x, 0) + supp = (x != 0).astype(x.dtype) + self.assertArraysEqual(support.support_nonzero(x), supp) + + def test_support_group_nonzero(self): + rng = onp.random.RandomState(0) + + x = rng.rand(20) * 2 - 1 + x = onp.where(x > 0, x, 0) + self.assertFalse(onp.all(x == 0)) + supp = onp.ones_like(x) + self.assertArraysEqual(support.support_group_nonzero(x), supp) + + x = onp.zeros(20) + supp = onp.zeros_like(x) + self.assertArraysEqual(support.support_group_nonzero(x), supp) + +if __name__ == '__main__': + absltest.main()
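A minimal sketch of how the support functions introduced in ``jaxopt/_src/support.py`` behave on a plain array, assuming the new ``jaxopt.support`` module is importable as added in this diff (the example array is illustrative)::

    import jax.numpy as jnp
    from jaxopt import support

    x = jnp.array([0.0, 1.5, 0.0, -2.0])

    # All coordinates are in the support: an all-ones mask.
    print(support.support_all(x))            # [1. 1. 1. 1.]

    # Only the non-zero coordinates are in the support.
    print(support.support_nonzero(x))        # [0. 1. 0. 1.]

    # The whole block is in the support as soon as any coordinate is non-zero.
    print(support.support_group_nonzero(x))             # [1. 1. 1. 1.]
    print(support.support_group_nonzero(jnp.zeros(3)))  # [0. 0. 0.]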
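And a sketch of the intended end-to-end use with ``ProximalGradient``, mirroring the example added to ``docs/non_smooth.rst`` and the test ``test_lasso_implicit_diff_sparse``; the data, the value of ``l1reg``, and the solver settings here are illustrative assumptions, not part of the diff::

    import jax
    import jax.numpy as jnp
    from sklearn import datasets

    from jaxopt import ProximalGradient, objective, prox, support

    X, y = datasets.make_regression(n_samples=10, n_features=10,
                                    n_informative=3, random_state=0)
    l1reg = 10.0
    w_init = jnp.zeros(X.shape[1])

    pg = ProximalGradient(fun=objective.least_squares, prox=prox.prox_lasso,
                          support=support.support_nonzero,
                          implicit_diff=True, tol=1e-3, maxiter=200)

    def solution(l1reg):
      return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

    # With support=support_nonzero, the Jacobian w.r.t. l1reg shares the support
    # of the solution: coordinates outside the support have zero derivative.
    print(solution(l1reg))
    print(jax.jacobian(solution)(l1reg))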