
Conversation

@hsimonfroy
Contributor

PR discussed in #738

  • Add SamplingAlgorithm and kernel transformations that thin them. They take a thinning integer and a SamplingAlgorithm/kernel, and return the same SamplingAlgorithm/kernel but iterated thinning times per step.

  • This is useful for reducing the computation and memory cost of high-throughput samplers, especially in high dimension. While the thin_algorithm function operates on a top_level_api SamplingAlgorithm, the thin_kernel version is relevant for adaptation algorithms. For instance, estimating the autocorrelation length (used to tune the momentum decoherence length in mclmc_adaptation) from the states of every step is computationally prohibitive in high dimension, see Subsampling for MCLMC tuning #738.

  • Both transformations have an additional info_transform Callable parameter that defines how to aggregate the sampler info across the thinning steps. For instance, we might want to average the logdensities and root-mean-square the energy_changes, which can easily be done with tree.map or tree.map_with_path; a minimal sketch is given right after this list.
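
For instance, an info_transform that root-mean-squares the energy changes while averaging the other fields could look like the following sketch (the field name energy_change follows the MCLMC info mentioned above; the function name and the use of jax.tree_util.tree_map_with_path are illustrative, not part of this PR):

import jax.numpy as jnp
from jax import tree_util

def per_field_info_transform(infos):
    # `infos` stacks the per-step info pytree along a leading axis of size `thinning`.
    def aggregate(path, x):
        if "energy_change" in tree_util.keystr(path):
            return jnp.sqrt(jnp.mean(x**2))  # root-mean-square the energy changes
        return jnp.mean(x, axis=0)  # average everything else, e.g. the logdensities
    return tree_util.tree_map_with_path(aggregate, infos)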

  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;

@hsimonfroy
Contributor Author

hsimonfroy commented May 18, 2025

  • Here is an example of how thin_algorithm and thin_kernel could be used for MCLMC (a sketch of what thin_kernel does internally is given after the note at the end of this comment):
import jax.numpy as jnp
import jax.random as jr
from jax import tree

import blackjax
from blackjax.mcmc.integrators import isokinetic_mclachlan
from blackjax.util import run_inference_algorithm
# thin_kernel and thin_algorithm are the transformations introduced by this PR.

logdf = lambda x: -(x**2).sum()
init_pos = jnp.ones(2)
init_key, tune_key, run_key = jr.split(jr.key(42), 3)

# Initialize the MCLMC state.
state = blackjax.mcmc.mclmc.init(
    position=init_pos,
    logdensity_fn=logdf,
    rng_key=init_key,
)

# Thinned kernel for the adaptation: each adaptation step advances the chain 16 times.
kernel = lambda inverse_mass_matrix: thin_kernel(
    blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdf,
        integrator=isokinetic_mclachlan,
        inverse_mass_matrix=inverse_mass_matrix,
    ),
    thinning=16,
    # Adequately aggregate info.energy_change (root-mean-square across the 16 steps).
    info_transform=lambda info: tree.map(lambda x: (x**2).mean() ** 0.5, info),
)

# Tune L and step_size with the thinned kernel.
state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
    mclmc_kernel=kernel,
    num_steps=100,
    state=state,
    rng_key=tune_key,
)

# Build the sampler with the tuned parameters.
sampler = blackjax.mclmc(
    logdensity_fn=logdf,
    L=params.L,
    step_size=params.step_size,
    inverse_mass_matrix=params.inverse_mass_matrix,
)

# Thinned sampler: each recorded sample is obtained from 16 kernel steps.
sampler = thin_algorithm(
    sampler,
    thinning=16,
    info_transform=lambda info: tree.map(jnp.mean, info),
)

state, history = run_inference_algorithm(
    rng_key=run_key,
    initial_state=state,
    inference_algorithm=sampler,
    num_steps=100,
)
    
  • NB: I exposed the Lfactor=0.4 parameter in mclmc_find_L_and_step_size, because if the thinning is too high, the ESS computed on the thinned samples would be larger than on the non-thinned samples, which leads to underestimating the autocorrelation length and therefore L. This can simply be compensated by increasing Lfactor. In practice, during my tests, I only found minor changes in the estimated L (with vs. without thinning) for reasonable thinning values, so I am not sure this option is necessary. In any case, one shouldn't thin to the point that it deteriorates the ESS, since one could then just take fewer sampling steps.
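
For intuition about what the thin_kernel call in the example above does, here is a minimal sketch of the thinning mechanics. It is an illustration under assumed conventions, not the PR's actual implementation; in particular, the function name and the default info_transform (which keeps only the last step's info) are assumptions:

import jax
from jax import tree

def thin_kernel_sketch(kernel, thinning, info_transform=lambda infos: tree.map(lambda x: x[-1], infos)):
    # Wrap a blackjax-style kernel (rng_key, state, **params) -> (state, info)
    # so that each call advances the chain `thinning` times.
    def thinned_kernel(rng_key, state, **kernel_kwargs):
        def one_step(state, key):
            return kernel(key, state, **kernel_kwargs)

        keys = jax.random.split(rng_key, thinning)
        state, infos = jax.lax.scan(one_step, state, keys)
        # `infos` stacks the per-step info along a leading axis of size `thinning`.
        return state, info_transform(infos)

    return thinned_kernel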

@junpenglao
Member

Overall LGTM, could you add some tests?

@hsimonfroy
Contributor Author

OK, the Python 3.12 test was passing a few weeks ago; something seems to have broken between pytest and pytest-xdist.

INTERNALERROR> pytest_benchmark.logger.PytestBenchmarkWarning: Benchmarks are automatically disabled because xdist plugin is active.Benchmarks cannot be performed reliably in a parallelized environment.

@hsimonfroy
Contributor Author

Hello @junpenglao and @reubenharry,
I had to remove the -n auto option in the pytest workflow to work around pytest-xdist breaking on the main branch. I don't know whether you have found another workaround.

Then I added tests in ThinInferenceAlgorithmTest, following the same chex format as RunInferenceAlgorithmTest.
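
For reference, here is a rough sketch of the kind of chex-style check such a test could contain. The asserted behaviour (thin_kernel applies the wrapped kernel exactly thinning times) follows the PR description, while the toy kernel, class name, and call signature are illustrative assumptions:

import chex
import jax
import jax.numpy as jnp

class ThinKernelToyTest(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_thinning_iterates_kernel(self):
        # Toy stand-in for a blackjax kernel: (rng_key, state) -> (state, info).
        def kernel(rng_key, state):
            del rng_key
            return state + 1.0, {"logdensity": -state}

        thinned = thin_kernel(  # from this PR
            kernel,
            thinning=4,
            info_transform=lambda info: jax.tree.map(jnp.mean, info),
        )
        state, info = self.variant(thinned)(jax.random.key(0), jnp.array(0.0))
        # Four inner steps per outer step, so the counter advances by 4.
        chex.assert_trees_all_close(state, jnp.array(4.0))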
