
probcomp/genjax


This is the research version of GenJAX; a more stable community version can be found here.

(Probabilistic programming language) GenJAX is a probabilistic programming language (PPL): a system that automates the writing of programs that compute with probability distributions, including sampling, variational approximation, gradient estimation for expected values, and more.
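As a rough illustration of "gradient estimation for expected values", the sketch below uses plain JAX rather than GenJAX's own API. It estimates d/dmu E[x^2] for x ~ Normal(mu, 1) using the reparameterization trick. The function name mc_objective and the sample size are illustrative assumptions. Since E[x^2] = mu^2 + 1, the exact gradient is 2 * mu, which gives a reference value to check against.

```python
import jax
import jax.numpy as jnp

# Target: d/dmu E_{x ~ Normal(mu, 1)}[x**2]. Since E[x**2] = mu**2 + 1,
# the exact gradient is 2 * mu.
def mc_objective(mu, key, n=100_000):
    # Reparameterize: x = mu + eps with eps ~ Normal(0, 1), so each
    # sample is a differentiable function of mu.
    eps = jax.random.normal(key, (n,))
    x = mu + eps
    # Monte Carlo estimate of E[x**2].
    return jnp.mean(x ** 2)

# Differentiating the Monte Carlo average yields an unbiased gradient estimate.
grad_estimate = jax.grad(mc_objective)(1.5, jax.random.PRNGKey(0))
print(grad_estimate)  # should be close to 2 * 1.5 = 3.0
```

A PPL automates this kind of derivation for richer programs. The hand-written version is only meant to show what quantity is being estimated.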

(With programmable inference) The design of GenJAX is centered on programmable inference: automation that allows users to express and customize Bayesian inference algorithms (algorithms for computing with posterior distributions: "x affects y, and I observe y; what are my new beliefs about x?"). Programmable inference covers advanced forms of Monte Carlo and variational inference methods.
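The following minimal sketch works through the posterior question above. Like the previous example, it is written in plain JAX and does not use GenJAX's API. It assumes x ~ Normal(0, 1) and y ~ Normal(x, 1), observes y = 2, and estimates the posterior mean of x with self-normalized importance sampling, using the prior as the proposal. The model, the observed value, and the helper posterior_mean_is are hypothetical choices made for illustration. For this conjugate model the exact posterior is Normal(y/2, 1/sqrt(2)), so the estimate should land near 1.0.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical model: x ~ Normal(0, 1), y ~ Normal(x, 1).
def posterior_mean_is(key, y_obs, n=100_000):
    # Propose from the prior: x ~ Normal(0, 1).
    xs = jax.random.normal(key, (n,))
    # Log importance weights = log likelihood log p(y_obs | x).
    log_w = norm.logpdf(y_obs, loc=xs, scale=1.0)
    # Self-normalize the weights in a numerically stable way.
    w = jax.nn.softmax(log_w)
    # Weighted average of the samples approximates E[x | y_obs].
    return jnp.sum(w * xs)

est = posterior_mean_is(jax.random.PRNGKey(0), 2.0)
print(est)  # should be close to the exact posterior mean 2.0 / 2 = 1.0
```

Swapping the proposal or the weighting scheme changes the inference algorithm while the model stays fixed. Customization of this kind is what the programmable-inference design is aimed at.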

GenJAX's automation is based on two key concepts: generative functions (GenJAX's version of probabilistic programs) and traces (samples from probabilistic programs). GenJAX provides:

(Fully vectorized & compatible with JAX) All of GenJAX's automation is fully compatible with JAX, so any program written in GenJAX can be vectorized with vmap and compiled with jit.
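To illustrate the traces and JAX-compatibility claims together, here is a hypothetical plain-JAX stand-in for a generative function. It is not GenJAX's actual interface. The function simulate samples two random choices and returns a dictionary "trace" that holds them along with their joint log density. Because simulate is a pure function of a PRNG key, it can be wrapped in jax.vmap and jax.jit to draw many traces in one compiled call. The name simulate and the dictionary layout are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical stand-in for a generative function: it samples its random
# choices and returns a "trace" recording them with their joint log density.
def simulate(key):
    k1, k2 = jax.random.split(key)
    x = jax.random.normal(k1)        # x ~ Normal(0, 1)
    y = x + jax.random.normal(k2)    # y ~ Normal(x, 1)
    log_density = norm.logpdf(x, 0.0, 1.0) + norm.logpdf(y, x, 1.0)
    return {"x": x, "y": y, "log_density": log_density}

# Because simulate is a pure JAX function, it composes with vmap and jit:
# draw 1,000 independent traces in one vectorized, compiled call.
batched_simulate = jax.jit(jax.vmap(simulate))
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
traces = batched_simulate(keys)
print(traces["x"].shape)  # (1000,)
```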
