This repository was archived by the owner on Apr 23, 2025. It is now read-only.
add vanilla HMC method #75
Draft
master wants to merge 6 commits into awslabs:main from master:hmc
Changes from all commits (6 commits):
9554a25  fix SGMCMC state repo
10a9527  add full-batch HMC implementation
a96283a  update documentation
cf2c710  lint the code
90974f4  Fix missing import
c1973e1  Fix optimizer freeze support in HMC
fortuna/prob_model/posterior/sgmcmc/hmc/__init__.py (1 addition)
@@ -0,0 +1 @@
HMC_NAME = "hmc"
fortuna/prob_model/posterior/sgmcmc/hmc/hmc_approximator.py (56 additions, 0 deletions)
@@ -0,0 +1,56 @@
from typing import Union

from fortuna.prob_model.posterior.sgmcmc.base import (
    SGMCMCPosteriorApproximator,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import (
    Preconditioner,
    identity_preconditioner,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
    StepSchedule,
    constant_schedule,
)
from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME


class HMCPosteriorApproximator(SGMCMCPosteriorApproximator):
    def __init__(
        self,
        n_samples: int = 10,
        n_thinning: int = 1,
        burnin_length: int = 1000,
        integration_steps: int = 50_000,
        step_schedule: Union[StepSchedule, float] = 3e-5,
    ) -> None:
        """
        HMC posterior approximator. It defines how the posterior distribution is approximated.

        Parameters
        ----------
        n_samples: int
            The desired number of posterior samples.
        n_thinning: int
            If `n_thinning` > 1, keep only one sample every `n_thinning` steps during the sampling phase.
        burnin_length: int
            Length of the initial burn-in phase, in steps.
        integration_steps: int
            Number of integration steps per trajectory.
        step_schedule: Union[StepSchedule, float]
            Either a constant `float` step size or a schedule function.
        """
        super().__init__(
            n_samples=n_samples,
            n_thinning=n_thinning,
        )
        if isinstance(step_schedule, float):
            step_schedule = constant_schedule(step_schedule)
        elif not callable(step_schedule):
            raise ValueError("`step_schedule` must be a float or a callable function.")
        self.burnin_length = burnin_length
        self.integration_steps = integration_steps
        self.step_schedule = step_schedule

    def __str__(self) -> str:
        return HMC_NAME
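For context, a minimal construction sketch of the class above; the argument values are illustrative only and not recommendations from this PR.

# Sketch: constructing the approximator defined in this diff.
from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_approximator import (
    HMCPosteriorApproximator,
)

approximator = HMCPosteriorApproximator(
    n_samples=10,           # posterior samples to keep
    n_thinning=2,           # keep one sample every 2 post-burn-in steps
    burnin_length=1000,     # steps discarded before sampling starts
    integration_steps=50,   # leapfrog steps per trajectory
    step_schedule=3e-5,     # a constant float; a StepSchedule callable also works
)
assert str(approximator) == "hmc"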
(new file, 67 additions; file name not shown on the page)
@@ -0,0 +1,67 @@
from typing import Optional

from fortuna.training.train_state import TrainState
from fortuna.training.callback import Callback
from fortuna.training.train_state_repository import TrainStateRepository
from fortuna.training.trainer import TrainerABC
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import (
    SGMCMCSamplingCallback,
)


class HMCSamplingCallback(SGMCMCSamplingCallback):
    def __init__(
        self,
        n_epochs: int,
        n_training_steps: int,
        n_samples: int,
        n_thinning: int,
        burnin_length: int,
        trainer: TrainerABC,
        state_repository: TrainStateRepository,
        keep_top_n_checkpoints: int,
    ):
        """
        Hamiltonian Monte Carlo (HMC) callback that collects samples after the initial burn-in phase.

        Parameters
        ----------
        n_epochs: int
            The number of epochs.
        n_training_steps: int
            The number of steps per epoch.
        n_samples: int
            The desired number of posterior samples.
        n_thinning: int
            Keep only one sample every `n_thinning` steps during the sampling phase.
        burnin_length: int
            Length of the initial burn-in phase, in steps.
        trainer: TrainerABC
            An instance of the trainer class.
        state_repository: TrainStateRepository
            An instance of the state repository.
        keep_top_n_checkpoints: int
            Number of past checkpoint files to keep.
        """
        super().__init__(
            trainer=trainer,
            state_repository=state_repository,
            keep_top_n_checkpoints=keep_top_n_checkpoints,
        )

        # Sample only after burn-in, at most `n_samples` times, once every
        # `n_thinning` steps.
        self._do_sample = (
            lambda current_step, samples_count: samples_count < n_samples
            and current_step > burnin_length
            and (current_step - burnin_length) % n_thinning == 0
        )

        # Count how many steps would trigger sampling; booleans sum as 0/1.
        total_samples = sum(
            self._do_sample(step, 0)
            for step in range(1, n_epochs * n_training_steps + 1)
        )
        if total_samples < n_samples:
            raise ValueError(
                f"The number of desired samples `n_samples` is {n_samples}. However, only "
                f"{total_samples} samples will be collected. Consider adjusting the burnin "
                "length, number of epochs, or the thinning parameter."
            )
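To make the sampling schedule concrete, here is a standalone sketch of the same predicate; the parameter values are chosen only for illustration.

# Standalone sketch of the sampling schedule above; numbers are illustrative.
burnin_length, n_thinning, n_samples = 1000, 5, 10

def do_sample(current_step: int, samples_count: int) -> bool:
    return (
        samples_count < n_samples
        and current_step > burnin_length
        and (current_step - burnin_length) % n_thinning == 0
    )

# The first sample is taken at step 1005 and the tenth at step 1050, so the
# run must last at least burnin_length + n_thinning * n_samples steps.
sampled_steps = [s for s in range(1, 1051) if do_sample(s, 0)][:n_samples]
assert sampled_steps[0] == 1005 and sampled_steps[-1] == 1050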
fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py (129 additions, 0 deletions)
@@ -0,0 +1,129 @@
import jax
import jax.flatten_util  # needed for jax.flatten_util.ravel_pytree below
import jax.numpy as jnp

from fortuna.typing import Array
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
    StepSchedule,
)
from fortuna.utils.random import generate_random_normal_like_tree
from jax._src.prng import PRNGKeyArray
from optax._src.base import PyTree
from optax import GradientTransformation
from typing import NamedTuple


class OptaxHMCState(NamedTuple):
    """Optax state for the HMC integrator."""

    count: Array
    rng_key: PRNGKeyArray
    momentum: PyTree
    params: PyTree
    hamiltonian: Array
    log_prob: Array


def hmc_integrator(
    integration_steps: int,
    rng_key: PRNGKeyArray,
    step_schedule: StepSchedule,
) -> GradientTransformation:
    """Optax implementation of the HMC integrator.

    Parameters
    ----------
    integration_steps: int
        Number of leapfrog integration steps in each trajectory.
    rng_key: PRNGKeyArray
        An initial random number generator.
    step_schedule: StepSchedule
        A function that takes the training step as input and returns the step size.
    """

    def init_fn(params):
        return OptaxHMCState(
            count=jnp.zeros([], jnp.int32),
            rng_key=rng_key,
            momentum=jax.tree_util.tree_map(jnp.zeros_like, params),
            params=params,
            # Start with a very low energy so the first MH correction accepts.
            hamiltonian=jnp.array(-1e6, jnp.float32),
            log_prob=jnp.zeros([], jnp.float32),
        )

    def update_fn(gradient, state, params):
        step_size = step_schedule(state.count)

        def leapfrog_step():
            # Move the parameters along the current momentum, then update the
            # momentum with the gradient.
            updates = jax.tree_util.tree_map(
                lambda m: m * step_size,
                state.momentum,
            )
            momentum = jax.tree_util.tree_map(
                lambda m, g: m + g * step_size,
                state.momentum,
                gradient,
            )
            return updates, OptaxHMCState(
                count=state.count + 1,
                rng_key=state.rng_key,
                momentum=momentum,
                params=state.params,
                hamiltonian=state.hamiltonian,
                log_prob=state.log_prob,
            )

        def mh_correction():
            key, new_key, uniform_key = jax.random.split(state.rng_key, 3)

            # Finish the trajectory with a half step for the momentum.
            momentum = jax.tree_util.tree_map(
                lambda m, g: m + g * step_size / 2,
                state.momentum,
                gradient,
            )

            momentum, _ = jax.flatten_util.ravel_pytree(momentum)
            kinetic = 0.5 * jnp.dot(momentum, momentum)
            hamiltonian = kinetic + state.log_prob
            accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))

            def _accept():
                # Keep the proposed parameters: no correction needed.
                empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
                return empty_updates, params, hamiltonian

            def _reject():
                # Revert to the parameters stored at the start of the trajectory.
                revert_updates = jax.tree_util.tree_map(
                    lambda sp, p: sp - p,
                    state.params,
                    params,
                )
                return revert_updates, state.params, state.hamiltonian

            updates, new_params, new_hamiltonian = jax.lax.cond(
                # Review comment on this line: "Following the comment above,
                # this line should become [suggestion not captured]. This is
                # equivalent to what you have written but with one operation
                # less. Alternatively, notice that [suggestion not captured].
                # All of these should be equivalent. Please check that the
                # lines I wrote are correct :-)"
                jax.random.uniform(uniform_key) < accept_prob,
                _accept,
                _reject,
            )

            # Resample the momentum and take the first half step of the next
            # trajectory.
            new_momentum = generate_random_normal_like_tree(key, gradient)
            new_momentum = jax.tree_util.tree_map(
                lambda m, g: m + g * step_size / 2,
                new_momentum,
                gradient,
            )

            return updates, OptaxHMCState(
                count=state.count + 1,
                rng_key=new_key,
                momentum=new_momentum,
                params=new_params,
                hamiltonian=new_hamiltonian,
                log_prob=state.log_prob,
            )

        # Run a Metropolis-Hastings correction at trajectory boundaries and a
        # plain leapfrog step otherwise.
        return jax.lax.cond(
            state.count % integration_steps == 0,
            mh_correction,
            leapfrog_step,
        )

    return GradientTransformation(init_fn, update_fn)
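To show how the integrator plugs into an optax-style loop, here is a toy sketch. It assumes the module above is importable; `log_prob` bookkeeping is handled elsewhere in Fortuna's SGMCMC machinery, so the acceptance test here is exercised only mechanically, and the target log p(x) = -0.5 * x^2 is chosen for illustration.

# Toy usage sketch of hmc_integrator as an optax GradientTransformation.
import optax

transform = hmc_integrator(
    integration_steps=10,
    rng_key=jax.random.PRNGKey(0),
    step_schedule=lambda step: 0.1,  # constant step size
)
params = jnp.array([1.0])
state = transform.init(params)
for _ in range(100):
    grads = -params  # gradient of log p(x) = -0.5 * x^2
    updates, state = transform.update(grads, state, params)
    params = optax.apply_updates(params, updates)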
Review comment: As discussed, you can avoid the minimum and the exponential here. You can define

log_accept_ratio = hamiltonian - state.hamiltonian

See later for the accept/reject part.
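Concretely, a self-contained sketch of the equivalence behind that suggestion; the helper names below are mine, not the reviewer's code. For u ~ Uniform(0, 1), the event u < min(1, exp(r)) is the same as log(u) < r, so the acceptance test can compare in log space directly.

import jax
import jax.numpy as jnp

def accept_exp(u, r):
    # Acceptance test as currently written in the diff.
    return u < jnp.minimum(1.0, jnp.exp(r))

def accept_log(u, r):
    # Reviewer's suggested form: no minimum, no exponential.
    return jnp.log(u) < r

u = jax.random.uniform(jax.random.PRNGKey(0), (10_000,))
r = jax.random.normal(jax.random.PRNGKey(1), (10_000,))
print(jnp.mean(accept_exp(u, r) == accept_log(u, r)))  # ~1.0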