Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Conversation

master
Copy link
Contributor

@master master commented May 29, 2023

Add full-batch Hamiltonian Monte Carlo implementation.

Pull request type

Please check the type of change your PR introduces:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

momentum, _ = jax.flatten_util.ravel_pytree(momentum)
kinetic = 0.5 * jnp.dot(momentum, momentum)
hamiltonian = kinetic + state.log_prob
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, you can avoid the minimum and the exponential here. You can define

log_accept_ratio = hamiltonian - state.hamiltonian

See later for the accept/reject part.

return revert_updates, state.params, state.hamiltonian

updates, new_params, new_hamiltonian = jax.lax.cond(
jax.random.uniform(uniform_key) < accept_prob,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the comment above, this line should become

jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio.

This is equivalent to what you have written but with one operation less. Alternatively, notice that -log(U) ~ Exponential(1)) if U~Uniform(0, 1). This means that you can also write

-jax.random.exponential(uniform_key)) < log_accept_ratio.

All of these should be equivalent. Please check that the lines I wrote are correct :-)

"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("HMCState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting to see the stored _hamiltonian here too?

**kwargs,
)
state = state.replace(
opt_state=state.opt_state._replace(log_prob=aux["loss"]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should opt_state be added to the parameters of HMCState?

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants