This is the research version of GenJAX; a (more stable) community version can be found here.
(Probabilistic programming language) GenJAX is a probabilistic programming language (PPL): a system that automates writing programs that compute with probability distributions, including sampling, variational approximation, gradient estimation for expected values, and more.
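To make "gradient estimation for expected values" concrete, here is a minimal sketch in plain JAX, not GenJAX code. It uses the reparameterization trick, which is one standard estimator among several; the objective and the function name `surrogate` are illustrative assumptions. For E over x ~ Normal(mu, 1) of x², the exact value is mu² + 1, so the exact gradient is 2·mu.

```python
import jax
import jax.numpy as jnp

# Illustrative objective (an assumption): E_{x ~ Normal(mu, 1)}[x**2] = mu**2 + 1.
def surrogate(mu, key, num_samples=100_000):
    # Reparameterize: x = mu + eps with eps ~ Normal(0, 1), so the noise
    # does not depend on mu and autodiff can flow through the samples.
    eps = jax.random.normal(key, (num_samples,))
    x = mu + eps
    # Monte Carlo estimate of the expectation; its gradient estimates 2 * mu.
    return jnp.mean(x ** 2)

grad_est = jax.grad(surrogate)(1.5, jax.random.PRNGKey(0))
print(grad_est)  # close to the exact gradient 3.0
```

The estimate is unbiased, and its noise shrinks as `num_samples` grows.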
(With programmable inference) The design of GenJAX is centered on programmable inference: automation which allows users to express and customize Bayesian inference algorithms (algorithms for computing with posterior distributions: "x affects y, and I observe y, what are my new beliefs about x?"). Programmable inference includes advanced forms of Monte Carlo and variational inference methods.
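The posterior question "x affects y, and I observe y, what are my new beliefs about x?" can be illustrated with a toy model where the answer is known exactly. The sketch below is plain JAX, not GenJAX's API. It assumes the model x ~ Normal(0, 1), y ~ Normal(x, 1) and the observation y = 2, for which the exact posterior is Normal(1, √0.5). It approximates the posterior mean with self-normalized importance sampling from the prior, one of the simplest Monte Carlo inference algorithms.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Toy model (an assumption for illustration):
#   x ~ Normal(0, 1)   prior belief about x
#   y ~ Normal(x, 1)   x affects y
def posterior_mean_is(key, y_obs, num_samples=100_000):
    # Propose candidate values of x from the prior.
    xs = jax.random.normal(key, (num_samples,))
    # Weight each candidate by the likelihood p(y_obs | x), in log space.
    log_w = norm.logpdf(y_obs, loc=xs, scale=1.0)
    # Self-normalize the weights (softmax is a stable log-space normalizer).
    w = jax.nn.softmax(log_w)
    return jnp.sum(w * xs)

est = posterior_mean_is(jax.random.PRNGKey(0), 2.0)
print(est)  # close to the exact posterior mean 1.0
```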
GenJAX's automation is based on two key concepts: generative functions (GenJAX's version of probabilistic programs) and traces (samples from probabilistic programs). GenJAX provides:
- Modeling language automation for constructing complex probability distributions from pieces
- Inference automation for constructing Monte Carlo samplers using convenient idioms (programs expressed by creating and editing traces), and variational inference automation using new extensions to automatic differentiation for expected values.
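The "creating and editing traces" idiom can be sketched in plain JAX. This is a hypothetical illustration, not GenJAX's trace type or API. Here a "trace" is assumed to be a dict mapping addresses (`"x"`, `"y"`) to random choices, `simulate` creates one by running the model forward, and `mh_step` edits the `"x"` choice and accepts or rejects the edit. Together these form a random-walk Metropolis–Hastings sampler. The names `log_joint`, `simulate`, `mh_step`, and `run_chain` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical plain-JAX illustration, not GenJAX's API: a "trace" is a
# dict mapping addresses ("x", "y") to random choices.

def log_joint(choices):
    # Model: x ~ Normal(0, 1), y ~ Normal(x, 1).
    return (norm.logpdf(choices["x"], 0.0, 1.0)
            + norm.logpdf(choices["y"], choices["x"], 1.0))

def simulate(key):
    # Create a trace by running the model forward.
    k1, k2 = jax.random.split(key)
    x = jax.random.normal(k1)
    y = x + jax.random.normal(k2)
    choices = {"x": x, "y": y}
    return choices, log_joint(choices)

def mh_step(key, choices):
    # Edit the trace: propose a new value at address "x", keep "y" fixed.
    k1, k2 = jax.random.split(key)
    proposed = dict(choices, x=choices["x"] + jax.random.normal(k1))
    # The random-walk proposal is symmetric, so the acceptance ratio is
    # the ratio of joint densities of the edited and original traces.
    log_alpha = log_joint(proposed) - log_joint(choices)
    accept = jnp.log(jax.random.uniform(k2)) < log_alpha
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(accept, new, old), proposed, choices)

def run_chain(key, y_obs, num_steps=50_000):
    # Constrain "y" to the observed value and start "x" at 0.
    init = {"x": jnp.array(0.0, dtype=jnp.float32),
            "y": jnp.array(y_obs, dtype=jnp.float32)}
    def body(choices, k):
        new = mh_step(k, choices)
        return new, new["x"]
    _, xs = jax.lax.scan(body, init, jax.random.split(key, num_steps))
    return xs

xs = run_chain(jax.random.PRNGKey(1), 2.0)
print(xs[1_000:].mean())  # near the exact posterior mean 1.0
```

Because the sampler only ever touches the `"x"` entry of the trace, the same edit-then-accept pattern extends to models with many addresses by choosing which entries to propose.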
(Fully vectorized & compatible with JAX) All of GenJAX's automation is fully compatible with JAX, implying that any program written in GenJAX can be `vmap`'d and `jit`-compiled.
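The sketch below shows what `vmap` and `jit` mean for a sampling program. It uses a plain JAX function as a stand-in, not a GenJAX generative function. The function name `sample_and_score` and the model x ~ Normal(0, 1), y ~ Normal(x, 1) are illustrative assumptions. `jax.vmap` maps the program over a batch of random keys, producing 1000 independent runs, and `jax.jit` compiles the batched program.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Stand-in program (an assumption): sample x ~ Normal(0, 1), then
# y ~ Normal(x, 1), and return y with the joint log density of (x, y).
def sample_and_score(key):
    k1, k2 = jax.random.split(key)
    x = jax.random.normal(k1)
    y = x + jax.random.normal(k2)
    return y, norm.logpdf(x) + norm.logpdf(y, loc=x)

# vmap over a batch of keys gives independent runs; jit compiles the batch.
batched = jax.jit(jax.vmap(sample_and_score))
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
ys, log_ps = batched(keys)
print(ys.shape, log_ps.shape)  # (1000,) (1000,)
```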