-
Couldn't load subscription status.
- Fork 121
Multi-Path Pathfinder #783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This update adds the multi pathfinder capabilities, parallel Pathfinder runs, LBFGS optimisation using Optax, and importance sampling. - Multi-Pathfinder: Supports parallel and vectorized sampling strategies, improving scalability. - LBFGS Optimizer: Refactored using Optax for better numerical stability and performance, with enhanced inverse Hessian estimation. - Importance Sampling: Added support for various methods, improving sampling accuracy. - Testing: Expanded tests for both Pathfinder and LBFGS to ensure correctness and stability.
| def psislw_wrapper(logiw_array): | ||
| def psislw(logiw_array): | ||
| result_logiw, result_k = az.psislw(np.array(logiw_array)) | ||
| return np.array(result_logiw), np.array(result_k) | ||
|
|
||
| return jax.pure_callback( | ||
| psislw, | ||
| (jnp.zeros_like(logiw_array), jnp.zeros((), dtype=jnp.float64)), | ||
| logiw_array, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add a pure JAX implementation of psislw (in a seperate PR), as pure_callback will be quite inefficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can. I have found a JAX implementation here https://gist.github.com/adamhaber/0556671340e0daa9e2c6e3fd535cd992 by @adamhaber as a good starting point
|
Hi @aphc14, could you isolate the change replacing jaxopt with optax in a seperate PR? |
|
Should this PR include the optax code changes while referencing the separate optax PR? Or should this PR revert the changes from optax to jaxopt? |
|
Hi! I think this would be very handy! I have used the PyMC Extras multipathfinder and works very well. Let me know if I could support in anyway (eg mypy errors of optax refactor 🤗) |
|
thanks @juanitorduz, I'll try to get this over the line |
Implement Multi-Pathfinder and Enhance Pathfinder and LBFGS Optimizers
High-Level Description
Multi-Pathfinder: Introduces the
multi_pathfinderfunction, enabling parallel and vectorized sampling strategies. Initial tests indicate thatmulti_pathfinderis functioning as expected. Feel free to test it out and provide feedback.LBFGS Optimizer: The optimizer has been refactored to use Optax.
single pathfinder fix: Fixed inaccurate calculations of S Z matrices, phi, log densities.
alpha_recover: Decoupling of alpha recover out of LBFGS optimisation.
Importance Sampling: Added support for various importance sampling methods.
Testing: Expanded and updated tests for both Pathfinder and LBFGS.
Current Status
This is a draft PR that requires some tidying up. There are existing linting errors from mypy that need to be addressed. Your feedback and testing are welcome to help refine these changes.
Checklist
maincommit.pre-committo check for any issues.resolves #763, #213, #461, #749, #704
related #465, #387