Note: This is the research version of GenJAX. A (more stable) community version can be found here.
GenJAX is a probabilistic programming language (PPL): a system that automates writing programs which compute with probability distributions, including sampling, variational approximation, gradient estimation for expected values, and more.
The design of GenJAX is centered on programmable inference: automation which allows users to express and customize Bayesian inference algorithms (algorithms for computing with posterior distributions: "x affects y, and I observe y, what are my new beliefs about x?"). Programmable inference includes advanced forms of Monte Carlo and variational inference methods.
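To make the posterior question concrete, here is a minimal sketch in plain JAX. It does not use GenJAX's API. It assumes a hypothetical toy model, x ~ Normal(0, 1) and y ~ Normal(x, 1), and estimates the posterior mean of x after observing y with self-normalized importance sampling, using the prior as the proposal. The helper names `log_likelihood` and `posterior_mean_is` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model (not from GenJAX): x ~ Normal(0, 1), y ~ Normal(x, 1).
def log_likelihood(x, y):
    return norm.logpdf(y, loc=x, scale=1.0)

def posterior_mean_is(key, y_obs, n=100_000):
    """Self-normalized importance sampling with the prior as the proposal."""
    xs = jax.random.normal(key, (n,))   # proposals drawn from the prior over x
    log_w = log_likelihood(xs, y_obs)   # importance weight = likelihood of the observed y
    w = jax.nn.softmax(log_w)           # normalize the weights stably in log space
    return jnp.sum(w * xs)              # weighted average approximates E[x | y]

est = posterior_mean_is(jax.random.PRNGKey(0), y_obs=1.0)
```

For this conjugate model the exact posterior is Normal(y/2, 1/2), so with y = 1.0 the estimate should land close to 0.5. A PPL automates exactly this kind of bookkeeping (sampling, scoring, weighting) for models where no closed form exists.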
GenJAX's automation is based on two key concepts:
- Generative functions – GenJAX's version of probabilistic programs
- Traces – samples from probabilistic programs
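The relationship between the two concepts can be sketched in plain JAX. This is not GenJAX's interface: `Trace` and `simulate` are hypothetical names, and the program is the same assumed toy model. Running the "generative function" yields a trace that records each random choice under an address, the program's return value, and the log density of the recorded choices.

```python
import jax
from typing import NamedTuple
from jax.scipy.stats import norm

class Trace(NamedTuple):
    choices: dict      # address -> sampled value
    retval: jax.Array  # value returned by the program
    score: jax.Array   # joint log density of all recorded choices

def simulate(key):
    """Hypothetical generative function: runs the program and records a trace."""
    k1, k2 = jax.random.split(key)
    x = jax.random.normal(k1)        # choice at address "x" ~ Normal(0, 1)
    y = x + jax.random.normal(k2)    # choice at address "y" ~ Normal(x, 1)
    score = norm.logpdf(x, 0.0, 1.0) + norm.logpdf(y, x, 1.0)
    return Trace(choices={"x": x, "y": y}, retval=x + y, score=score)

tr = simulate(jax.random.PRNGKey(0))
```

Because the trace keeps every choice and its density, inference code can inspect it, constrain addresses to observed values, or propose edits to individual choices without rerunning the whole program by hand.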
GenJAX provides:
- Modeling language automation for constructing complex probability distributions from pieces
- Inference automation for constructing Monte Carlo samplers using convenient idioms (programs expressed by creating and editing traces), and variational inference automation using new extensions to automatic differentiation for expected values
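As a rough illustration of what "automatic differentiation for expected values" means, the sketch below estimates the gradient of an expectation with the standard reparameterization trick, written by hand in plain JAX. It does not show GenJAX's own AD extensions; the objective E over x ~ Normal(mu, 1) of x², whose exact gradient is 2·mu, and the names `surrogate` and `grad_estimate` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# Objective: E_{x ~ Normal(mu, 1)}[x**2]; its exact derivative in mu is 2 * mu.
def surrogate(mu, eps):
    x = mu + eps            # reparameterize: x = mu + eps with eps ~ Normal(0, 1)
    return jnp.mean(x ** 2) # Monte Carlo estimate of the expectation

def grad_estimate(key, mu, n=10_000):
    eps = jax.random.normal(key, (n,))    # noise drawn independently of mu
    return jax.grad(surrogate)(mu, eps)   # unbiased estimate of d/dmu E[x**2]

g = grad_estimate(jax.random.PRNGKey(0), 1.5)
```

Moving the randomness into `eps` makes the sampled value a differentiable function of `mu`, so ordinary `jax.grad` yields an unbiased gradient estimate. Automating such transformations for general probabilistic programs is the harder problem the bullet above refers to.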
All of GenJAX's automation is fully compatible with JAX, so any program written in GenJAX can be `vmap`'d and `jit`-compiled.
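The following plain-JAX sketch shows what that compatibility buys, using ordinary JAX transformations rather than any GenJAX-specific call. `sample_and_weight` is a hypothetical single-particle importance sampler for the assumed toy model; `jax.vmap` batches it over many PRNG keys and `jax.jit` compiles the batched function.

```python
import jax
from jax.scipy.stats import norm

def sample_and_weight(key, y_obs):
    """Hypothetical single importance sample: x ~ Normal(0, 1), log weight = log N(y_obs; x, 1)."""
    x = jax.random.normal(key)
    log_w = norm.logpdf(y_obs, x, 1.0)
    return x, log_w

# vmap maps over a batch of keys (y_obs is shared); jit compiles the batched function with XLA.
batched = jax.jit(jax.vmap(sample_and_weight, in_axes=(0, None)))

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
xs, log_ws = batched(keys, 1.0)
```

Writing the per-particle logic once and letting `vmap` handle batching keeps inference code simple while still running the particles in parallel on an accelerator.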
This repository is optimized for development with Claude Code, Anthropic's AI coding assistant. The codebase includes comprehensive `CLAUDE.md` files that provide context and guidance for Claude Code to work effectively with GenJAX.