2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -38,7 +38,7 @@ jobs:
less requirements.txt | grep 'pytest\|chex' | xargs -i -t pip install {}
- name: Run tests
run: |
pytest -n auto -vv -m "not benchmark" --cov=blackjax --cov-report=xml --cov-report=term tests
pytest -vv -m benchmark --cov=blackjax --cov-report=xml --cov-report=term tests
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
5 changes: 4 additions & 1 deletion blackjax/adaptation/mclmc_adaptation.py
@@ -52,6 +52,7 @@ def mclmc_find_L_and_step_size(
num_effective_samples=150,
diagonal_preconditioning=True,
params=None,
Lfactor=0.4,
):
"""
Finds the optimal value of the parameters for the MCLMC algorithm.
@@ -82,6 +83,8 @@ def mclmc_find_L_and_step_size(
Whether to do diagonal preconditioning (i.e. a mass matrix)
params
Initial params to start tuning from (optional)
Lfactor
The factor scaling the estimated autocorrelation length to obtain the momentum decoherence length L.

Returns
-------
@@ -136,7 +139,7 @@ def mclmc_find_L_and_step_size(

if num_steps3 >= 2: # at least 2 samples for ESS estimation
state, params = make_adaptation_L(
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=Lfactor
)(state, params, num_steps, part2_key)
total_num_tuning_integrator_steps += num_steps3

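The `Lfactor` used in the third tuning stage is now exposed to callers instead of being hard-coded to 0.4. A minimal sketch of passing it through, assuming `kernel`, `state`, and `tune_key` are prepared as in the `thin_kernel` docstring example further down this page:

.. code::

    state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
        mclmc_kernel=kernel,
        num_steps=10_000,
        state=state,
        rng_key=tune_key,
        Lfactor=0.4,  # default; scales the estimated autocorrelation length into L
    )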
142 changes: 141 additions & 1 deletion blackjax/util.py
@@ -1,7 +1,7 @@
"""Utility functions for BlackJax."""

from functools import partial
from typing import Callable, Union
from typing import Callable, NamedTuple, Union

import jax.numpy as jnp
from jax import jit, lax
@@ -314,3 +314,143 @@ def incremental_value_update(
)
total += weight
return total, average


def thin_algorithm(
sampling_algorithm: SamplingAlgorithm,
thinning: int = 1,
info_transform: Callable = lambda x: x,
) -> SamplingAlgorithm:
"""
Return a new sampling algorithm that performs `thinning` iterations of the given algorithm,
so that only one state is returned every `thinning` steps.
This is useful to reduce the computation and memory cost of high-throughput samplers, especially in high dimensions.

Parameters
----------
sampling_algorithm: SamplingAlgorithm
The sampling algorithm to thin.
thinning: int
The number of algorithm steps to perform before returning the state.
info_transform: Callable
A function defining how to aggregate the algorithm's info across the `thinning` steps.
By default, the infos from all `thinning` steps are returned (stacked by `lax.scan`).

Returns
-------
SamplingAlgorithm
A thinned version of the sampling algorithm.

Example
-------
.. code::

logdf = lambda x: -(x**2).sum()
init_pos = jnp.ones(2)
init_key, run_key = jr.split(jr.key(43), 2)

state = blackjax.mcmc.mclmc.init(
position=init_pos,
logdensity_fn=logdf,
rng_key=init_key
)

sampler = blackjax.mclmc(
logdensity_fn=logdf,
L=L,
step_size=step_size,
inverse_mass_matrix=inverse_mass_matrix,
)

sampler = thin_algorithm(
sampler,
thinning=16,
info_transform=lambda info: tree.map(jnp.mean, info),
)

state, history = run_inference_algorithm(
rng_key=run_key,
initial_state=state,
inference_algorithm=sampler,
num_steps=100,
)
"""

def step_fn(rng_key: PRNGKey, state: NamedTuple) -> tuple[NamedTuple, NamedTuple]:
step = lambda state, rng_key: sampling_algorithm.step(rng_key, state)
keys = split(rng_key, thinning)
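# Run `thinning` steps in a single scan: the state is carried through, and the
# stacked per-step infos are passed to info_transform.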
state, info = lax.scan(step, state, keys)
return state, info_transform(info)

return SamplingAlgorithm(sampling_algorithm.init, step_fn)


def thin_kernel(
kernel: Callable, thinning: int = 1, info_transform=lambda x: x
) -> Callable:
"""
Return a thinned version of a kernel that runs the kernel `thinning` times before returning the state.
This is useful to reduce the computation and memory cost of high-throughput samplers, especially in high dimensions.

Parameters
----------
kernel: Callable
The kernel to thin.
thinning: int
The number of kernel steps to perform before returning the state.
info_transform: Callable
A function defining how to aggregate the algorithm's info across the `thinning` steps.
By default, the infos from all `thinning` steps are returned (stacked by `lax.scan`).

Returns
-------
Callable
A thinned version of the kernel.


Example
-------
.. code::

logdf = lambda x: -(x**2).sum()
init_pos = jnp.ones(2)
init_key, tune_key = jr.split(jr.key(42), 2)

state = blackjax.mcmc.mclmc.init(
position=init_pos,
logdensity_fn=logdf,
rng_key=init_key
)

kernel = lambda inverse_mass_matrix: thin_kernel(
blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdf,
integrator=isokinetic_mclachlan,
inverse_mass_matrix=inverse_mass_matrix,
),

# Return every 16th state, which notably decreases computation and memory cost
# when estimating the high-dimensional autocorrelation length during tuning.
thinning=16,

# Aggregate info.energy_change as a root mean square across the thinned steps.
info_transform=lambda info: tree.map(lambda x: (x**2).mean()**.5, info)
)

state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=100,
state=state,
rng_key=tune_key,
)
"""

def thinned_kernel(
rng_key: PRNGKey, state: NamedTuple, *args, **kwargs
) -> tuple[NamedTuple, NamedTuple]:
step = lambda state, rng_key: kernel(rng_key, state, *args, **kwargs)
keys = split(rng_key, thinning)
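# Same pattern as in thin_algorithm: carry the state through the scan, stack the
# per-step infos, and aggregate them with info_transform.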
state, info = lax.scan(step, state, keys)
return state, info_transform(info)

return thinned_kernel
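The two helpers compose: `thin_kernel` thins the kernel used by `mclmc_find_L_and_step_size` during tuning, and `thin_algorithm` thins the resulting sampler. A minimal sketch, reusing the names from the docstring examples above and mirroring the test below, which rescales the tuned `L` by the thinning factor:

.. code::

    thinning = 16

    # Tune with the thinned kernel from the thin_kernel example above.
    state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
        mclmc_kernel=kernel, num_steps=10_000, state=state, rng_key=tune_key
    )
    params = params._replace(L=params.L * thinning)  # compensate L for thinning

    # Sample with a thinned version of the tuned sampler.
    sampler = blackjax.mclmc(
        logdf,
        L=params.L,
        step_size=params.step_size,
        inverse_mass_matrix=params.inverse_mass_matrix,
    )
    sampler = thin_algorithm(sampler, thinning=thinning)
    state, history = run_inference_algorithm(
        rng_key=run_key,
        initial_state=state,
        inference_algorithm=sampler,
        num_steps=100,
    )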
135 changes: 130 additions & 5 deletions tests/test_util.py
@@ -1,16 +1,26 @@
from functools import partial

import chex
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest, parameterized
from jax import jit
from jax import numpy as jnp
from jax import random as jr
from jax import tree, vmap

import blackjax
from blackjax.util import run_inference_algorithm, store_only_expectation_values
from blackjax.util import (
run_inference_algorithm,
store_only_expectation_values,
thin_algorithm,
thin_kernel,
)


class RunInferenceAlgorithmTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.key(42)
self.key = jr.key(42)
self.algorithm = blackjax.hmc(
logdensity_fn=self.logdensity_fn,
inverse_mass_matrix=jnp.eye(2),
@@ -41,7 +51,7 @@ def logdensity_fn(x):
10,
)

init_key, state_key, run_key = jax.random.split(self.key, 3)
init_key, state_key, run_key = jr.split(self.key, 3)
initial_state = blackjax.mcmc.mclmc.init(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key
)
@@ -106,5 +116,120 @@ def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))


class ThinInferenceAlgorithmTest(chex.TestCase):
# logdf = lambda self, x: - (x**2).sum(-1) / 2 # Gaussian
logdf = (
lambda self, x: -((x[::2] - 1) ** 2 + (x[1::2] - x[::2] ** 2) ** 2).sum(-1) / 2
) # Rosenbrock
d = 2
init_pos = jnp.ones(d)
rng_keys = jr.split(jr.key(42), 2)
num_steps = 10_000

def warmup(self, rng_key, num_steps, thinning: int = 1):
from blackjax.mcmc.integrators import isokinetic_mclachlan

init_key, tune_key = jr.split(rng_key, 2)

state = blackjax.mcmc.mclmc.init(
position=self.init_pos, logdensity_fn=self.logdf, rng_key=init_key
)

if thinning == 1:
kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=self.logdf,
integrator=isokinetic_mclachlan,
inverse_mass_matrix=inverse_mass_matrix,
)
else:
kernel = lambda inverse_mass_matrix: thin_kernel(
blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=self.logdf,
integrator=isokinetic_mclachlan,
inverse_mass_matrix=inverse_mass_matrix,
),
thinning=thinning,
info_transform=lambda info: tree.map(
lambda x: (x**2).mean() ** 0.5, info
),
)

state, config, n_steps = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=state,
rng_key=tune_key,
# frac_tune3=0.
)
n_steps *= thinning
config = config._replace(
L=config.L * thinning
) # NOTE: compensate L for thinning
return state, config, n_steps

def run(self, rng_key, state, config, num_steps, thinning: int = 1):
sampler = blackjax.mclmc(
self.logdf,
L=config.L,
step_size=config.step_size,
inverse_mass_matrix=config.inverse_mass_matrix,
)
if thinning != 1:
sampler = thin_algorithm(
sampler,
thinning=thinning,
info_transform=lambda info: tree.map(jnp.mean, info),
)

state, history = run_inference_algorithm(
rng_key=rng_key,
initial_state=state,
inference_algorithm=sampler,
num_steps=num_steps,
# progress_bar=True,
)
return state, history

def test_thin(self):
# Test thin kernel in warmup
state, config, n_steps = jit(
vmap(partial(self.warmup, num_steps=self.num_steps, thinning=1))
)(self.rng_keys)
config = tree.map(lambda x: jnp.median(x, 0), config)
state_thin, config_thin, n_steps_thin = jit(
vmap(partial(self.warmup, num_steps=self.num_steps, thinning=4))
)(self.rng_keys)
config_thin = tree.map(lambda x: jnp.median(x, 0), config_thin)

rtol = 5e-1
np.testing.assert_allclose(config_thin.L, config.L, rtol=rtol)
np.testing.assert_allclose(config_thin.step_size, config.step_size, rtol=rtol)
np.testing.assert_allclose(
config_thin.inverse_mass_matrix, config.inverse_mass_matrix, rtol=rtol
)

# Test thin algorithm in run
state, history = jit(
vmap(partial(self.run, config=config, num_steps=self.num_steps, thinning=1))
)(self.rng_keys, state)
samples = jnp.concatenate(history[0].position)
state_thin, history_thin = jit(
vmap(
partial(
self.run, config=config_thin, num_steps=self.num_steps, thinning=4
)
)
)(self.rng_keys, state_thin)
samples_thin = jnp.concatenate(history_thin[0].position)

rtol, atol = 1e-1, 1e-1
np.testing.assert_allclose(
samples_thin.mean(0), samples.mean(0), rtol=rtol, atol=atol
)
np.testing.assert_allclose(
jnp.cov(samples_thin.T), jnp.cov(samples.T), rtol=rtol, atol=atol
)


if __name__ == "__main__":
absltest.main()