Skip to content

Conversation

@aphc14
Copy link

@aphc14 aphc14 commented Mar 8, 2025

Implement Multi-Pathfinder and Enhance Pathfinder and LBFGS Optimizers

High-Level Description

  • Multi-Pathfinder: Introduces the multi_pathfinder function, enabling parallel and vectorized sampling strategies. Initial tests indicate that multi_pathfinder is functioning as expected. Feel free to test it out and provide feedback.

  • LBFGS Optimizer: The optimizer has been refactored to use Optax.

  • single pathfinder fix: Fixed inaccurate calculations of S Z matrices, phi, log densities.

  • alpha_recover: Decoupling of alpha recover out of LBFGS optimisation.

  • Importance Sampling: Added support for various importance sampling methods.

  • Testing: Expanded and updated tests for both Pathfinder and LBFGS.

Current Status

This is a draft PR that requires some tidying up. There are existing linting errors from mypy that need to be addressed. Your feedback and testing are welcome to help refine these changes.

Checklist

  • Ensure the PR title clearly describes the changes.
  • Provide links to all relevant issues, discussions, and PRs.
  • Rebase the branch on the latest main commit.
  • Ensure commit messages follow the guidelines.
  • Verify that the code respects current naming conventions.
  • Ensure docstrings follow the numpy style guide.
  • Run pre-commit to check for any issues.
  • Confirm that there are tests covering the changes.
  • Update the documentation to reflect the new changes.
  • If applicable, add or update related examples.

resolves #763, #213, #461, #749, #704

related #465, #387

This update adds the multi pathfinder capabilities, parallel Pathfinder runs, LBFGS optimisation using Optax, and importance sampling.

- Multi-Pathfinder: Supports parallel and vectorized sampling strategies, improving scalability.
- LBFGS Optimizer: Refactored using Optax for better numerical stability and performance, with enhanced inverse Hessian estimation.
- Importance Sampling: Added support for various methods, improving sampling accuracy.
- Testing: Expanded tests for both Pathfinder and LBFGS to ensure correctness and stability.
Comment on lines +487 to +496
def psislw_wrapper(logiw_array):
def psislw(logiw_array):
result_logiw, result_k = az.psislw(np.array(logiw_array))
return np.array(result_logiw), np.array(result_k)

return jax.pure_callback(
psislw,
(jnp.zeros_like(logiw_array), jnp.zeros((), dtype=jnp.float64)),
logiw_array,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a pure JAX implementation of psislw (in a seperate PR), as pure_callback will be quite inefficient.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can. I have found a JAX implementation here https://gist.github.com/adamhaber/0556671340e0daa9e2c6e3fd535cd992 by @adamhaber as a good starting point

@junpenglao
Copy link
Member

Hi @aphc14, could you isolate the change replacing jaxopt with optax in a seperate PR?

@aphc14
Copy link
Author

aphc14 commented May 18, 2025

Should this PR include the optax code changes while referencing the separate optax PR? Or should this PR revert the changes from optax to jaxopt?

@juanitorduz
Copy link

Hi! I think this would be very handy! I have used the PyMC Extras multipathfinder and works very well. Let me know if I could support in anyway (eg mypy errors of optax refactor 🤗)

@aphc14
Copy link
Author

aphc14 commented Jul 26, 2025

thanks @juanitorduz, I'll try to get this over the line

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Multi-path pathfinder instead of just one-path pathfinder implementation

3 participants