Sparse root VJP for Lasso penalty #274

Draft · wants to merge 12 commits into main
101 changes: 101 additions & 0 deletions benchmarks/sparse_root_prox_gradient_benchmark.py
@@ -0,0 +1,101 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Implicit differentiation of lasso.
==================================
"""

import time

import jax
import jax.numpy as jnp

from jaxopt import objective
from jaxopt import prox
from jaxopt import ProximalGradient
from jaxopt import support
from jaxopt._src import linear_solve

from sklearn import datasets
from sklearn import model_selection


def outer_objective(theta, init_inner, data, support, implicit_diff_solve):
"""Validation loss."""
X_tr, X_val, y_tr, y_val = data
# We use the bijective mapping lam = jnp.exp(theta) to ensure positivity.
lam = jnp.exp(theta)

solver = ProximalGradient(
fun=objective.least_squares,
support=support,
prox=prox.prox_lasso,
implicit_diff=True,
implicit_diff_solve=implicit_diff_solve,
maxiter=5000,
tol=1e-10)

# The format is run(init_params, hyperparams_prox, *args, **kwargs)
# where *args and **kwargs are passed to `fun`.
w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params

y_pred = jnp.dot(X_val, w_fit)
loss_value = jnp.mean((y_pred - y_val) ** 2)

# We return w_fit as auxiliary data.
# Auxiliary data is stored in the optimizer state (see below).
return loss_value, w_fit


X, y = datasets.make_regression(
n_samples=100, n_features=6000, n_informative=10, random_state=0)

n_samples = X.shape[0]
data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)

exp_lam_max = jnp.max(jnp.abs(X.T @ y))
lam = jnp.log(exp_lam_max / (100 * n_samples))

init_inner = jnp.zeros(X.shape[1])


t0 = time.time()
grad = jax.grad(outer_objective, has_aux=True)
gradient, coef_ = grad(
lam, init_inner, data, support.support_all, linear_solve.solve_cg)
gradient.block_until_ready()
delta_t = time.time() - t0
desc = 'Gradients w/o support, CG'
print(f'{desc} ({delta_t:.3f} sec.): {gradient}')

t0 = time.time()
grad = jax.grad(outer_objective, has_aux=True)
gradient, coef_ = grad(
lam, init_inner, data, support.support_all, linear_solve.solve_normal_cg)
gradient.block_until_ready()
delta_t = time.time() - t0
desc = 'Gradients w/o support, normal CG'
print(f'{desc} ({delta_t:.3f} sec.): {gradient}')

t0 = time.time()
grad = jax.grad(outer_objective, has_aux=True)
gradient, coef_ = grad(
lam, init_inner, data, support.support_nonzero, linear_solve.solve_cg)
gradient.block_until_ready()
delta_t = time.time() - t0
desc = 'Gradients w/ masked support'
print(f'{desc} ({delta_t:.3f} sec.): {gradient}')
115 changes: 115 additions & 0 deletions benchmarks/sparse_root_vjp_benchmark.py
@@ -0,0 +1,115 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark VJP of Lasso, with and without specifying the support."""

import time
import numpy as onp
import jax
import jax.numpy as jnp

from typing import Sequence
from absl import app
from sklearn import datasets

from jaxopt import implicit_diff as idf
from jaxopt._src import linear_solve
from jaxopt import objective
from jaxopt import prox
from jaxopt import support
from jaxopt import tree_util
from jaxopt._src import test_util


def lasso_optimality_fun(params, lam, X, y):
step = params - jax.grad(objective.least_squares)(params, (X, y))
return prox.prox_lasso(step, l1reg=lam, scaling=1.) - params


def get_vjp(lam, X, y, sol, support_fun, maxiter, solve_fn=linear_solve.solve_cg):
def solve(matvec, b):
return solve_fn(matvec, b, tol=1e-6, maxiter=maxiter)

vjp = lambda g: idf.root_vjp(optimality_fun=lasso_optimality_fun,
support_fun=support_fun,
sol=sol,
args=(lam, X, y),
cotangent=g,
solve=solve)[0]
return vjp


def benchmark_vjp(vjp, X, supp=None, desc=''):
t0 = time.time()
result = jax.vmap(vjp)(jnp.eye(X.shape[1]))
result.block_until_ready()
delta_t = time.time() - t0

size_support = onp.sum(result != 0)

if supp is not None:
result = result[supp]
print(f'{desc} ({delta_t:.3f} sec.): {result} '
f'(size of the support: {size_support:d})')


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

X, y = datasets.make_regression(n_samples=100, n_features=1000,
n_informative=5, random_state=0)
print(f'Number of samples: {X.shape[0]:d}')
print(f'Number of features: {X.shape[1]:d}')

lam = 1e-3 * onp.max(onp.abs(X.T @ y)) / X.shape[0]
print(f'Value of lambda: {lam:.5f}')

sol = test_util.lasso_skl(X, y, lam, tol=1e-6, fit_intercept=False)
supp = (sol != 0)
print(f'Size of the support of the solution: {supp.sum():d}')

optimality = tree_util.tree_l2_norm(lasso_optimality_fun(sol, lam, X, y))
print(f'Optimality of the solution (L2-norm): {optimality:.5f}')

jac_num = test_util.lasso_skl_jac(X, y, lam, tol=1e-8, eps=1e-8)
print(f'Numerical Jacobian: {jac_num[supp]}')

# Compute the Jacobian wrt. lambda, without using the information about the
# support of the solution. This is the default behavior in JAXopt, and
# requires solving a linear system with a 1000x1000 dense matrix. Ignoring
# the support of the solution with CG leads to an inaccurate Jacobian.
vjp = get_vjp(lam, X, y, sol, support.support_all, maxiter=1000)
benchmark_vjp(vjp, X, supp=supp, desc='Jacobian w/o support, CG')

vjp = get_vjp(lam, X, y, sol, support.support_all, maxiter=1000,
solve_fn=linear_solve.solve_normal_cg)
benchmark_vjp(vjp, X, supp=supp, desc='Jacobian w/o support, normal CG')

# Compute the Jacobian wrt. lambda, restricting the data X and the solution
# to the support of the solution. This requires solving a linear system with
# a 4x4 dense matrix. Restricting the support this way is not jit-friendly,
# and will require new compilation if the size of the support changes.
vjp = get_vjp(lam, X[:, supp], y, sol[supp], support.support_all, maxiter=1000)
benchmark_vjp(vjp, X[:, supp], desc='Jacobian w/ restricted support')

# Compute the Jacobian wrt. lambda, by masking the linear system to solve
# with the support of the solution. This requires solving a linear system with
# a 1000x1000 sparse matrix. Masking with the support is jit-friendly.
vjp = get_vjp(lam, X, y, sol, support.support_nonzero, maxiter=1000)
benchmark_vjp(vjp, X, supp=supp, desc='Jacobian w/ masked support')


if __name__ == '__main__':
app.run(main)
25 changes: 25 additions & 0 deletions docs/implicit_diff.rst
@@ -58,6 +58,26 @@ We can also compose ``ridge_reg_solution`` with other functions::
* :ref:`sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py`
* :ref:`sphx_glr_auto_examples_implicit_diff_sparse_coding.py`

Non-smooth functions
--------------------

When the function :math:`f(x, \theta)` is non-smooth (e.g., in Lasso), implicit
differentiation can still be applied to compute the Jacobian
:math:`\partial x^\star(\theta)`. However, the linear system to solve to obtain
the Jacobian (or, alternatively, the vector-Jacobian product) must be restricted
to the (generalized) *support* :math:`S` of the solution :math:`x^\star`. To
give a hint to the linear solver called in ``root_vjp``, you may specify a
support function ``support`` that will restrict the linear system to the support
of the solution.

The ``support`` function must return a pytree with the same structure and dtype
as the solution, whose :math:`j`-th entry equals 1 if the corresponding coordinate
is in the support (:math:`x_{j} \in S`) and 0 otherwise. The support function
depends on the non-smooth function being optimized; see :ref:`support functions
<support_functions>` for examples of support functions. Note that the
support function merely masks out coordinates outside of the support, making it
fully compatible with ``jit`` compilation.
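
For illustration, here is a minimal sketch of computing a vector-Jacobian product
for the Lasso restricted to the support of the solution, following the
``root_vjp`` keyword arguments used in the benchmarks of this PR. Here ``sol``,
``lam``, ``X``, ``y`` and ``solve`` are assumed to be the solution, the
regularization strength, the data and a linear solver (e.g. conjugate gradient)::

  import jax
  from jaxopt import implicit_diff as idf
  from jaxopt import objective, prox, support

  def optimality_fun(params, lam, X, y):
    # Fixed-point residual of the proximal gradient step for the Lasso.
    step = params - jax.grad(objective.least_squares)(params, (X, y))
    return prox.prox_lasso(step, l1reg=lam, scaling=1.) - params

  vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun,
                               support_fun=support.support_nonzero,
                               sol=sol,
                               args=(lam, X, y),
                               cotangent=g,
                               solve=solve)[0]

Passing ``support.support_nonzero`` restricts the linear system to the nonzero
coordinates of ``sol``, while ``support.support_all`` recovers the default
(unrestricted) behavior.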

Custom solvers
--------------

@@ -92,3 +112,8 @@ of roots of functions.
<https://arxiv.org/abs/2105.15183>`_,
Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert.
ArXiv preprint.

* `Implicit Differentiation for Fast Hyperparameter Selection in Non-Smooth Convex Learning
<https://www.jmlr.org/papers/volume23/21-0486/21-0486.pdf>`_,
Quentin Bertrand, Quentin Klopfenstein, Mathurin Massias, Mathieu Blondel, Samuel Vaiter, Alexandre Gramfort, Joseph Salmon.
Journal of Machine Learning Research (JMLR).
38 changes: 38 additions & 0 deletions docs/non_smooth.rst
@@ -82,6 +82,23 @@ and autodiff of unrolled iterations if ``implicit_diff=False``. See the
* :ref:`sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py`
* :ref:`sphx_glr_auto_examples_implicit_diff_sparse_coding.py`

When using implicit differentiation, you can optionally specify a support
function ``support`` to give a hint to the linear solver called in ``root_vjp``
and only solve the linear system restricted to the support of the solution::

from jaxopt.support import support_nonzero

def solution(l1reg):
pg = ProximalGradient(fun=least_squares, prox=prox_lasso,
support=support_nonzero, implicit_diff=True)
return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

# Both the solution & the Jacobian have the same support
print(solution(l1reg))
print(jax.jacobian(solution)(l1reg))

See the :ref:`implicit differentiation <implicit_diff>` section for more details.

.. _block_coordinate_descent:

Block coordinate descent
@@ -134,3 +151,24 @@ The following operators are available.
jaxopt.prox.prox_group_lasso
jaxopt.prox.prox_ridge
jaxopt.prox.prox_non_negative_ridge

.. _support_functions:

Support functions
-----------------

Support functions are functions of the form :math:`S(x)` that return 1 for every
coordinate of :math:`x` in the support, and 0 otherwise:

.. math::

S(x)_{j} := \begin{cases} 1 & \textrm{if $x_{j} \in S$} \\ 0 & \textrm{otherwise} \end{cases}

The following support functions are available.

.. autosummary::
:toctree: _autosummary

jaxopt.support.support_all
jaxopt.support.support_nonzero
jaxopt.support.support_group_nonzero
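
As a rough illustration (a conceptual sketch, not necessarily the exact library
implementation), ``support_nonzero`` behaves like a nonzero mask with the same
pytree structure and dtype as its input::

  import jax
  import jax.numpy as jnp

  def support_nonzero_sketch(x):
    # 1 where a coordinate is nonzero, 0 elsewhere, preserving the pytree
    # structure and dtype of x.
    return jax.tree_util.tree_map(lambda leaf: (leaf != 0).astype(leaf.dtype), x)

  print(support_nonzero_sketch(jnp.array([0., 1.3, 0., -2.])))
  # [0. 1. 0. 1.]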
2 changes: 2 additions & 0 deletions jaxopt/_src/base.py
@@ -204,10 +204,12 @@ def run(self,
run = self._run

if getattr(self, "implicit_diff", True):
support_fun = getattr(self, "support", None)
reference_signature = getattr(self, "reference_signature", None)
decorator = idf.custom_root(
self.optimality_fun,
has_aux=True,
support_fun=support_fun,
solve=self.implicit_diff_solve,
reference_signature=reference_signature)
run = decorator(run)