Thin kernel and sampling algorithm #791
Conversation
import jax.numpy as jnp
import jax.random as jr
from jax import tree

import blackjax
from blackjax.mcmc.integrators import isokinetic_mclachlan
from blackjax.util import run_inference_algorithm
# thin_kernel and thin_algorithm are the transformations introduced by this PR.

# Isotropic Gaussian log-density (unnormalized).
logdf = lambda x: -(x**2).sum()

init_pos = jnp.ones(2)
init_key, tune_key, run_key = jr.split(jr.key(42), 3)

state = blackjax.mcmc.mclmc.init(
    position=init_pos,
    logdensity_fn=logdf,
    rng_key=init_key,
)

# Thinned kernel for adaptation: each call advances the chain 16 steps.
kernel = lambda inverse_mass_matrix: thin_kernel(
    blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdf,
        integrator=isokinetic_mclachlan,
        inverse_mass_matrix=inverse_mass_matrix,
    ),
    thinning=16,
    # Aggregate info.energy_change by root-mean-square over the 16 steps.
    info_transform=lambda info: tree.map(lambda x: (x**2).mean() ** 0.5, info),
)

state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
    mclmc_kernel=kernel,
    num_steps=100,
    state=state,
    rng_key=tune_key,
)

sampler = blackjax.mclmc(
    logdensity_fn=logdf,
    L=params.L,
    step_size=params.step_size,
    inverse_mass_matrix=params.inverse_mass_matrix,
)

# Thinned algorithm for sampling: only every 16th state is stored,
# and the infos in between are averaged.
sampler = thin_algorithm(
    sampler,
    thinning=16,
    info_transform=lambda info: tree.map(jnp.mean, info),
)

state, history = run_inference_algorithm(
    rng_key=run_key,
    initial_state=state,
    inference_algorithm=sampler,
    num_steps=100,
)
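For intuition, here is a minimal sketch of what a thinned kernel transformation can look like: the base kernel is iterated `thinning` times with `jax.lax.scan`, only the last state is kept, and the stacked per-step infos are reduced by `info_transform`. This is an illustration written for this review, not necessarily the implementation in the PR; the keyword-forwarding convention is an assumption.

import jax

def thin_kernel_sketch(kernel, thinning, info_transform=lambda info: info):
    # Wrap `kernel` so that one call advances the chain `thinning` steps.
    def thinned_kernel(rng_key, state, **kernel_kwargs):
        keys = jax.random.split(rng_key, thinning)

        def one_step(state, key):
            # Assumed blackjax convention: kernel(rng_key, state, ...) -> (state, info).
            state, info = kernel(key, state, **kernel_kwargs)
            return state, info

        # Keep only the final state; the infos are stacked along a leading
        # axis of length `thinning`, then aggregated by `info_transform`.
        final_state, infos = jax.lax.scan(one_step, state, keys)
        return final_state, info_transform(infos)

    return thinned_kernel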
Overall LGTM, could you add some tests?
OK, the Python 3.12 test was passing a few weeks ago; something seems to have broken between pytest and xdist.
Hello @junpenglao and @reubenharry, I added tests as requested.
PR discussed in #738
Add SamplingAlgorithm and kernel transformations that make them thinned. They take a `thinning` integer and a SamplingAlgorithm/kernel, and return the same SamplingAlgorithm/kernel but iterated `thinning` times. This is useful to reduce the computation and memory cost of high-throughput samplers, especially in high dimension.

While the `thin_algorithm` function operates on top-level-API SamplingAlgorithms, the `thin_kernel` version is relevant for adaptation algorithms. For instance, estimating the autocorrelation length (used to tune the momentum decoherence length in `mclmc_adaptation`) from the states of every step is computationally prohibitive in high dimension; see Subsampling for MCLMC tuning #738.

Both transformations have an additional `info_transform` Callable parameter that defines how to aggregate the sampler infos across the `thinning` steps. For instance, we might want to average the log-densities and root-mean-square the energy changes, which is easily done with `tree.map` or `tree.map_with_path` (see the sketch below).

- the code is based on the latest `main` commit;
- `pre-commit` is installed and configured on your machine, and you ran it before opening the PR;
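To make the last point concrete, here is a hedged sketch of a per-field `info_transform` using `jax.tree_util.tree_map_with_path`: root-mean-square the `energy_change` field and average everything else. The field name follows the MCLMC info used in the example above; the helper name is made up for illustration and is not part of the PR.

from jax.tree_util import tree_map_with_path

def mixed_info_transform(info):
    # `info` is the pytree of per-step sampler infos, stacked along a
    # leading axis of length `thinning`.
    def aggregate(path, leaf):
        # Path entries for named fields expose a `name` attribute.
        is_energy = any(getattr(entry, "name", None) == "energy_change" for entry in path)
        # Root-mean-square the energy changes, plain mean for everything else.
        return (leaf**2).mean() ** 0.5 if is_energy else leaf.mean()
    return tree_map_with_path(aggregate, info)

This could then be passed as `info_transform=mixed_info_transform` to either `thin_kernel` or `thin_algorithm`.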