diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index fa644898a..890a30e2a 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -52,6 +52,7 @@ def mclmc_find_L_and_step_size(
     num_effective_samples=150,
     diagonal_preconditioning=True,
     params=None,
+    Lfactor=0.4,
 ):
     """
     Finds the optimal value of the parameters for the MCLMC algorithm.
@@ -82,6 +83,8 @@
         Whether to do diagonal preconditioning (i.e. a mass matrix)
     params
         Initial params to start tuning from (optional)
+    Lfactor
+        The factor that scales the estimated autocorrelation length to obtain the momentum decoherence length L.

     Returns
     -------
@@ -136,7 +139,7 @@ def mclmc_find_L_and_step_size(

     if num_steps3 >= 2:  # at least 2 samples for ESS estimation
         state, params = make_adaptation_L(
-            mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
+            mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=Lfactor
         )(state, params, num_steps, part2_key)
         total_num_tuning_integrator_steps += num_steps3

diff --git a/blackjax/util.py b/blackjax/util.py
index 8cdcd45ee..1effcbdc1 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -1,7 +1,7 @@
 """Utility functions for BlackJax."""

 from functools import partial
-from typing import Callable, Union
+from typing import Callable, NamedTuple, Union

 import jax.numpy as jnp
 from jax import jit, lax
@@ -314,3 +314,143 @@ def incremental_value_update(
     )
     total += weight
     return total, average
+
+
+def thin_algorithm(
+    sampling_algorithm: SamplingAlgorithm,
+    thinning: int = 1,
+    info_transform: Callable = lambda x: x,
+) -> SamplingAlgorithm:
+    """
+    Return a new sampling algorithm that performs `thinning` iterations of the given algorithm,
+    so that only one state is returned every `thinning` steps.
+    This is useful to reduce the computation and memory cost of high-throughput samplers, especially in high dimension.
+
+    Parameters
+    ----------
+    sampling_algorithm: SamplingAlgorithm
+        The sampling algorithm to thin.
+    thinning: int
+        The number of algorithm steps to perform before returning the state.
+    info_transform: Callable
+        A function defining how to aggregate the algorithm info across the `thinning` steps.
+        By default, the info from every step is returned.
+
+    Returns
+    -------
+    SamplingAlgorithm
+        A thinned version of the sampling algorithm.
+
+    Example
+    -------
+    .. code::
+
+        logdf = lambda x: -(x**2).sum()
+        init_pos = jnp.ones(2)
+        init_key, run_key = jr.split(jr.key(43), 2)
+
+        state = blackjax.mcmc.mclmc.init(
+            position=init_pos,
+            logdensity_fn=logdf,
+            rng_key=init_key
+        )
+
+        sampler = blackjax.mclmc(
+            logdensity_fn=logdf,
+            L=1.0,  # e.g. obtained from blackjax.mclmc_find_L_and_step_size
+            step_size=0.1,
+            inverse_mass_matrix=jnp.ones(2),
+        )
+
+        sampler = thin_algorithm(
+            sampler,
+            thinning=16,
+            info_transform=lambda info: tree.map(jnp.mean, info),
+        )
+
+        state, history = run_inference_algorithm(
+            rng_key=run_key,
+            initial_state=state,
+            inference_algorithm=sampler,
+            num_steps=100,
+        )
+    """
+
+    def step_fn(rng_key: PRNGKey, state: NamedTuple) -> tuple[NamedTuple, NamedTuple]:
+        step = lambda state, rng_key: sampling_algorithm.step(rng_key, state)
+        keys = split(rng_key, thinning)
+        state, info = lax.scan(step, state, keys)
+        return state, info_transform(info)
+
+    return SamplingAlgorithm(sampling_algorithm.init, step_fn)
+
+
+def thin_kernel(
+    kernel: Callable, thinning: int = 1, info_transform=lambda x: x
+) -> Callable:
+    """
+    Return a thinned version of a kernel that runs the kernel `thinning` times before returning the state.
+    This is useful to reduce the computation and memory cost of high-throughput samplers, especially in high dimension.
+
+    Parameters
+    ----------
+    kernel: Callable
+        The kernel to thin.
+    thinning: int
+        The number of kernel steps to perform before returning the state.
+    info_transform: Callable
+        A function defining how to aggregate the algorithm info across the `thinning` steps.
+        By default, the info from every step is returned.
+
+    Returns
+    -------
+    Callable
+        A thinned version of the kernel.
+
+
+    Example
+    -------
+    .. code::
+
+        logdf = lambda x: -(x**2).sum()
+        init_pos = jnp.ones(2)
+        init_key, tune_key = jr.split(jr.key(42), 2)
+
+        state = blackjax.mcmc.mclmc.init(
+            position=init_pos,
+            logdensity_fn=logdf,
+            rng_key=init_key
+        )
+
+        kernel = lambda inverse_mass_matrix: thin_kernel(
+            blackjax.mcmc.mclmc.build_kernel(
+                logdensity_fn=logdf,
+                integrator=isokinetic_mclachlan,
+                inverse_mass_matrix=inverse_mass_matrix,
+            ),
+
+            # Return every 16th state; this especially reduces computation and memory
+            # cost when estimating the autocorrelation length in high dimension during tuning.
+            thinning=16,
+
+            # Adequately aggregate info.energy_change (root mean square over the thinned steps)
+            info_transform=lambda info: tree.map(lambda x: (x**2).mean() ** 0.5, info),
+        )
+
+        state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=100,
+            state=state,
+            rng_key=tune_key,
+        )
+    """
+
+    def thinned_kernel(
+        rng_key: PRNGKey, state: NamedTuple, *args, **kwargs
+    ) -> tuple[NamedTuple, NamedTuple]:
+        step = lambda state, rng_key: kernel(rng_key, state, *args, **kwargs)
+        keys = split(rng_key, thinning)
+        state, info = lax.scan(step, state, keys)
+        return state, info_transform(info)
+
+    return thinned_kernel
diff --git a/tests/test_util.py b/tests/test_util.py
index 78198f013..c38e0d14d 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -1,16 +1,26 @@
+from functools import partial
+
 import chex
-import jax
-import jax.numpy as jnp
+import numpy as np
 from absl.testing import absltest, parameterized
+from jax import jit
+from jax import numpy as jnp
+from jax import random as jr
+from jax import tree, vmap

 import blackjax
-from blackjax.util import run_inference_algorithm, store_only_expectation_values
+from blackjax.util import (
+    run_inference_algorithm,
+    store_only_expectation_values,
+    thin_algorithm,
+    thin_kernel,
+)


 class RunInferenceAlgorithmTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.key(42)
+        self.key = jr.key(42)
         self.algorithm = blackjax.hmc(
             logdensity_fn=self.logdensity_fn,
             inverse_mass_matrix=jnp.eye(2),
@@ -41,7 +51,7 @@ def logdensity_fn(x):
             10,
         )

-        init_key, state_key, run_key = jax.random.split(self.key, 3)
+        init_key, state_key, run_key = jr.split(self.key, 3)
         initial_state = blackjax.mcmc.mclmc.init(
             position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key
         )
@@ -106,5 +116,139 @@ def logdensity_fn(x):
         return -0.5 * jnp.sum(jnp.square(x))


+class ThinInferenceAlgorithmTest(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.logdf = lambda x: -(x**2).sum(-1) / 2  # Gaussian
+        # self.logdf = (
+        #     lambda x: -((x[::2] - 1) ** 2 + (x[1::2] - x[::2] ** 2) ** 2).sum(-1) / 2
+        # )  # Rosenbrock
+        dim = 2
+        self.init_pos = jnp.ones(dim)
+        self.rng_keys = jr.split(jr.key(42), 2)
+        self.num_steps = 10_000
+
+    def warmup(self, rng_key, num_steps, thinning: int = 1):
+        from blackjax.mcmc.integrators import isokinetic_mclachlan
+
+        init_key, tune_key = jr.split(rng_key, 2)
+
+        state = blackjax.mcmc.mclmc.init(
+            position=self.init_pos, logdensity_fn=self.logdf, rng_key=init_key
+        )
+
+        if thinning == 1:
+            kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel(
+                logdensity_fn=self.logdf,
+                integrator=isokinetic_mclachlan,
+                inverse_mass_matrix=inverse_mass_matrix,
+            )
+        else:
+            kernel = lambda inverse_mass_matrix: thin_kernel(
+                blackjax.mcmc.mclmc.build_kernel(
+                    logdensity_fn=self.logdf,
+                    integrator=isokinetic_mclachlan,
+                    inverse_mass_matrix=inverse_mass_matrix,
+                ),
+                thinning=thinning,
+                info_transform=lambda info: tree.map(
+                    lambda x: (x**2).mean() ** 0.5, info
+                ),
+            )
+
+        state, config, n_steps = blackjax.mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=num_steps,
+            state=state,
+            rng_key=tune_key,
+            # frac_tune3=0.
+        )
+        n_steps *= thinning
+        config = config._replace(
+            L=config.L * thinning
+        )  # NOTE: compensate L for thinning
+        return state, config, n_steps
+
+    def run_algo(self, rng_key, state, config, num_steps, thinning: int = 1):
+        sampler = blackjax.mclmc(
+            self.logdf,
+            L=config.L,
+            step_size=config.step_size,
+            inverse_mass_matrix=config.inverse_mass_matrix,
+        )
+        if thinning != 1:
+            sampler = thin_algorithm(
+                sampler,
+                thinning=thinning,
+                info_transform=lambda info: tree.map(jnp.mean, info),
+            )
+
+        state, history = run_inference_algorithm(
+            rng_key=rng_key,
+            initial_state=state,
+            inference_algorithm=sampler,
+            num_steps=num_steps,
+            # progress_bar=True,
+        )
+        return state, history
+
+    def test_thin(self):
+        """
+        Compare results obtained from thinning kernel or algorithm vs. no thinning.
+        """
+        # Test thin kernel in warmup
+        state, config, n_steps = jit(
+            vmap(partial(self.warmup, num_steps=self.num_steps, thinning=1))
+        )(self.rng_keys)
+        config = tree.map(lambda x: jnp.median(x, 0), config)
+        state_thin, config_thin, n_steps_thin = jit(
+            vmap(partial(self.warmup, num_steps=self.num_steps, thinning=4))
+        )(self.rng_keys)
+        config_thin = tree.map(lambda x: jnp.median(x, 0), config_thin)
+
+        # Assert that the found parameters are close
+        rtol, atol = 1e-1, 1e-1
+        np.testing.assert_allclose(config_thin.L, config.L, rtol=rtol, atol=atol)
+        np.testing.assert_allclose(
+            config_thin.step_size, config.step_size, rtol=rtol, atol=atol
+        )
+        np.testing.assert_allclose(
+            config_thin.inverse_mass_matrix,
+            config.inverse_mass_matrix,
+            rtol=rtol,
+            atol=atol,
+        )
+
+        # Test thin algorithm in run_algo
+        state, history = jit(
+            vmap(
+                partial(
+                    self.run_algo, config=config, num_steps=self.num_steps, thinning=1
+                )
+            )
+        )(self.rng_keys, state)
+        samples = jnp.concatenate(history[0].position)
+        state_thin, history_thin = jit(
+            vmap(
+                partial(
+                    self.run_algo,
+                    config=config_thin,
+                    num_steps=self.num_steps,
+                    thinning=4,
+                )
+            )
+        )(self.rng_keys, state_thin)
+        samples_thin = jnp.concatenate(history_thin[0].position)
+
+        # Assert that the sample statistics are close
+        rtol, atol = 1e-1, 1e-1
+        np.testing.assert_allclose(
+            samples_thin.mean(0), samples.mean(0), rtol=rtol, atol=atol
+        )
+        np.testing.assert_allclose(
+            jnp.cov(samples_thin.T), jnp.cov(samples.T), rtol=rtol, atol=atol
+        )
+
+
 if __name__ == "__main__":
     absltest.main()
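Reviewer note (not part of the diff): a minimal end-to-end sketch of how the additions compose, assembled from the new docstring examples and ThinInferenceAlgorithmTest. The target density, dimension, step counts, and thinning factor below are illustrative placeholders; all function names and signatures follow the diff itself.

    import jax.numpy as jnp
    from jax import random as jr
    from jax import tree

    import blackjax
    from blackjax.mcmc.integrators import isokinetic_mclachlan
    from blackjax.util import run_inference_algorithm, thin_algorithm, thin_kernel

    logdf = lambda x: -(x**2).sum() / 2  # illustrative standard Gaussian target
    thinning = 4  # illustrative thinning factor
    init_key, tune_key, run_key = jr.split(jr.key(0), 3)

    state = blackjax.mcmc.mclmc.init(
        position=jnp.ones(2), logdensity_fn=logdf, rng_key=init_key
    )

    # Tune with a thinned kernel; aggregate per-step info with a root-mean-square reduction.
    kernel = lambda inverse_mass_matrix: thin_kernel(
        blackjax.mcmc.mclmc.build_kernel(
            logdensity_fn=logdf,
            integrator=isokinetic_mclachlan,
            inverse_mass_matrix=inverse_mass_matrix,
        ),
        thinning=thinning,
        info_transform=lambda info: tree.map(lambda x: (x**2).mean() ** 0.5, info),
    )
    state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
        mclmc_kernel=kernel,
        num_steps=1_000,
        state=state,
        rng_key=tune_key,
        Lfactor=0.4,  # knob exposed by this diff; 0.4 reproduces the previously hard-coded value
    )
    params = params._replace(L=params.L * thinning)  # compensate L for thinning, as in the test

    # Sample with a thinned algorithm, keeping one state every `thinning` steps.
    sampler = thin_algorithm(
        blackjax.mclmc(
            logdf,
            L=params.L,
            step_size=params.step_size,
            inverse_mass_matrix=params.inverse_mass_matrix,
        ),
        thinning=thinning,
        info_transform=lambda info: tree.map(jnp.mean, info),
    )
    state, history = run_inference_algorithm(
        rng_key=run_key,
        initial_state=state,
        inference_algorithm=sampler,
        num_steps=1_000,
    )

As in the test, the tuned L is rescaled by the thinning factor because the tuner only ever sees every `thinning`-th state.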