GenJAX is a probabilistic programming language (PPL): a system which provides automation for writing programs which perform computations on probability distributions, including sampling, variational approximation, gradient estimation for expected values, and more.
The design of GenJAX is centered on programmable inference: automation which allows users to express and customize Bayesian inference algorithms (algorithms for computing with posterior distributions: "x affects y, and I observe y, what are my new beliefs about x?"). Programmable inference includes advanced forms of Monte Carlo and variational inference methods.
GenJAX's automation is based on two key concepts:
- Generative functions – GenJAX's version of probabilistic programs
- Traces – samples from probabilistic programs
GenJAX provides:
- Modeling language automation for constructing complex probability distributions from pieces
- Inference automation for constructing Monte Carlo samplers using convenient idioms (programs expressed by creating and editing traces), and variational inference automation using new extensions to automatic differentiation for expected values
This repository is also a POPL'26 artifact submitted alongside the paper Probabilistic Programming with Vectorized Programmable Inference.
Canonical artifact version: v1.0.10. Use this release for artifact evaluation.
It contains the GenJAX implementation (including source code and tests), extensive documentation, curated agentic context (see the AGENTS.md files throughout the codebase) that allows users of Claude Code, Codex, and other coding agents to get up to speed with the system quickly, and several of the case studies used in the empirical evaluation.
The snippets below develop the polynomial regression example from our paper's Overview section in GenJAX: we compose the model from generative functions, vectorize importance sampling to scale the number of particles, extend the model with stochastic branching to capture outliers, and finish with a programmable kernel that mixes enumerative Gibbs updates with Hamiltonian Monte Carlo. The full code can be found in examples/curvefit; the code here can be run as a linear notebook-style walkthrough.
We begin by expressing the polynomial regression model as a composition of generative functions (@gen-decorated Python functions).
Each random choice (invocation of a generative function) is tagged with a string address ("a", "b", "c", "obs"), which is used to construct a structured representation of the model’s random variables, called a trace.
In GenJAX, packaging the coefficients inside a callable Lambda Pytree is a convenient way to allow downstream computations to call the curve directly, while the trace retains access to its parameters.
```python
from genjax import gen, normal
from genjax.core import Const
from genjax import Pytree
import jax.numpy as jnp


@Pytree.dataclass
class Lambda(Pytree):
    # Wrap polynomial coefficients in a callable pytree so traces retain parameters.
    f: Const[object]
    dynamic_vals: jnp.ndarray
    static_vals: Const[tuple] = Const(())

    def __call__(self, *x):
        return self.f.value(*x, *self.static_vals.value, self.dynamic_vals)


def polyfn(x, coeffs):
    # Deterministic quadratic curve evaluated at x.
    a, b, c = coeffs[0], coeffs[1], coeffs[2]
    return a + b * x + c * x**2


@gen
def polynomial():
    a = normal(0.0, 1.0) @ "a"
    b = normal(0.0, 1.0) @ "b"
    c = normal(0.0, 1.0) @ "c"
    return Lambda(Const(polyfn), jnp.array([a, b, c]))


@gen
def point(x, curve):
    y_det = curve(x)
    y_obs = normal(y_det, 0.05) @ "obs"
    return y_obs


@gen
def npoint_curve(xs):
    curve = polynomial() @ "curve"
    ys = point.vmap(in_axes=(0, None))(xs, curve) @ "ys"
    return curve, (xs, ys)


xs = jnp.linspace(0.0, 1.0, 8)
trace = npoint_curve.simulate(xs)
print(trace.get_choices()["curve"].keys())
print(trace.get_choices()["ys"]["obs"].shape)
```

Vectorizing the `point` generative function with `vmap` mirrors the Overview section's Figure 3: the resulting trace preserves the hierarchical structure of the polynomial's coefficients while lifting the observation random choice into a vectorized, array-valued version. This structure-preserving vectorization is what later enables us to reason about datasets consisting of many points (and other inference logic) in a vectorized fashion.
The generative function interface supplies a small set of methods (`simulate`, `generate`, `assess`, `update`) that we can compose into inference algorithms.
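As a quick orientation, here is a minimal sketch of these four methods applied to the `polynomial` model above; the call shapes follow their usage later in this walkthrough, so treat it as illustrative rather than a complete API reference:

```python
# Illustrative sketch (signatures follow the usage elsewhere in this README;
# `polynomial` takes no arguments, so the trailing *args are empty).
tr = polynomial.simulate()                      # sample a trace from the model
chm = tr.get_choices()                          # choice map with addresses "a", "b", "c"
log_p, curve = polynomial.assess(chm)           # joint density (and return value) of a full choice map
tr2, log_w = polynomial.generate({"a": 0.5})    # constrain "a", sample the rest, get a weight
tr3, dw, _ = polynomial.update(tr, {"a": 0.0})  # edit an existing trace, get a weight increment
```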
Here, we implement likelihood weighting (importance sampling): a single-particle routine constrains the observation site to a fixed value via the `generate` interface, while a vectorized wrapper scales up the number of particles. The guess-and-check logic (sampling, then computing an importance weight), implemented internally by `generate`, is the same across particles; only the array dimensions vary with the particle count.
```python
from jax.scipy.special import logsumexp
from genjax.pjax import modular_vmap
import jax.numpy as jnp


def single_particle_importance(model, xs, ys_obs):
    # Draw a single constrained trace and compute its importance weight.
    trace, log_weight = model.generate({"ys": {"obs": ys_obs}}, xs)
    return trace, log_weight


def vectorized_importance_sampling(model, xs, ys_obs, num_particles):
    # Lift the single-particle routine across an explicit particle axis.
    sampler = modular_vmap(
        single_particle_importance,
        in_axes=(None, None, None),
        axis_size=num_particles,
    )
    return sampler(model, xs, ys_obs)


def log_marginal_likelihood(log_weights):
    return logsumexp(log_weights) - jnp.log(log_weights.shape[0])


xs = jnp.linspace(0.0, 1.0, 8)
trace = npoint_curve.simulate(xs)
_, (_, ys_obs) = trace.get_retval()
traces, log_weights = vectorized_importance_sampling(
    npoint_curve, xs, ys_obs, num_particles=512
)
print(traces.get_choices()["curve"]["a"].shape, log_marginal_likelihood(log_weights))
```

Running on a GPU allows us to increase the number of particles as far as memory allows, just as in the scaling curves shown in Figure 5 of the paper.
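As an aside, the `log_marginal_likelihood` helper above implements the standard importance sampling evidence estimate in log space: with $N$ particles and log weights $\log w_i$,

$$
\log \hat{Z} = \log\left(\frac{1}{N}\sum_{i=1}^{N} w_i\right) = \operatorname{logsumexp}(\log w_{1:N}) - \log N .
$$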
Modeling real datasets often benefits from considering explanations of the data that include heterogeneous noise processes. Following the Overview, we enrich the observation model with stochastic branching that classifies each datapoint as an inlier or an outlier. The latent `is_outlier` switch feeds a `Cond` combinator that chooses between a tight Gaussian noise model and a broad uniform alternative; both branches write to the same observation address, so later inference can target the entire `ys` subtree uniformly.
```python
from genjax import Cond, flip, uniform


@gen
def inlier_branch(mean, extra_noise):
    # Inlier observations stay near the quadratic trend.
    return normal(mean, 0.1) @ "obs"


@gen
def outlier_branch(_, extra_noise):
    # Outliers come from a broad, curve-independent distribution.
    return uniform(-2.0, 2.0) @ "obs"


@gen
def point_with_outliers(x, curve, outlier_rate=0.1, extra_noise=5.0):
    is_outlier = flip(outlier_rate) @ "is_outlier"
    cond_model = Cond(outlier_branch, inlier_branch)
    y_det = curve(x)
    return cond_model(is_outlier, y_det, extra_noise) @ "y"


@gen
def npoint_curve_with_outliers(xs, outlier_rate=0.1):
    curve = polynomial() @ "curve"
    ys = point_with_outliers.vmap(
        in_axes=(0, None, None, None)
    )(xs, curve, outlier_rate, 5.0) @ "ys"
    return curve, (xs, ys)


xs = jnp.linspace(0.0, 1.0, 8)
trace = npoint_curve_with_outliers.simulate(xs)
choices = trace.get_choices()["ys"]
print(choices["is_outlier"].shape, choices["y"]["obs"].shape)
```

The resulting trace contains a boolean vector of outlier indicators alongside the observations, matching the mixture-structured traces shown in Figure 6.
To improve inference accuracy on the richer model, we combine discrete and continuous updates within a single programmable kernel. Enumerative Gibbs updates each `is_outlier` choice by scoring the two possible values with `assess` before resampling, while Hamiltonian Monte Carlo refines the continuous parameters (targeted via a selection, a way to address choices within a trace). Both steps operate on traces, and they compose sequentially without requiring special-case code.
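For a single point, the exact Gibbs conditional over the binary indicator needs only the two joint scores computed by `assess`:

$$
P(\texttt{is\_outlier} = 1 \mid \text{all other choices}) = \frac{\exp(\ell_{\mathrm{true}})}{\exp(\ell_{\mathrm{false}}) + \exp(\ell_{\mathrm{true}})},
$$

where $\ell_{\mathrm{false}}$ and $\ell_{\mathrm{true}}$ are the log joint densities with the indicator fixed to each value. Sampling from a categorical with logits $[\ell_{\mathrm{false}}, \ell_{\mathrm{true}}]$ draws from exactly this conditional, which is what the kernel below does.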
```python
import jax
import jax.numpy as jnp
from genjax.core import Const, sel
from genjax.distributions import categorical
from genjax.pjax import modular_vmap, seed
from genjax.inference import hmc, chain


def enumerative_gibbs_outliers(trace, xs, ys, outlier_rate=0.1):
    curve_params = trace.get_choices()["curve"]
    curve = Lambda(
        Const(polyfn),
        jnp.array([curve_params["a"], curve_params["b"], curve_params["c"]]),
    )

    def update_single_point(x, y_obs):
        chm_false = {"is_outlier": False, "y": {"obs": y_obs}}
        # Score the inlier explanation for the current observation.
        log_false, _ = point_with_outliers.assess(chm_false, x, curve, outlier_rate, 5.0)
        chm_true = {"is_outlier": True, "y": {"obs": y_obs}}
        # Score the outlier explanation for the same observation.
        log_true, _ = point_with_outliers.assess(chm_true, x, curve, outlier_rate, 5.0)
        logits = jnp.array([log_false, log_true])
        return categorical.sample(logits=logits) == 1

    new_outliers = modular_vmap(update_single_point)(xs, ys)
    gen_fn = trace.get_gen_fn()
    args = trace.get_args()
    new_trace, _, _ = gen_fn.update(
        trace, {"ys": {"is_outlier": new_outliers}}, *args[0], **args[1]
    )
    return new_trace


def mixed_gibbs_hmc_kernel(xs, ys, hmc_step_size=0.01, hmc_n_steps=10, outlier_rate=0.1):
    def kernel(trace):
        trace = enumerative_gibbs_outliers(trace, xs, ys, outlier_rate)
        return hmc(
            trace,
            sel(("curve", "a")) | sel(("curve", "b")) | sel(("curve", "c")),
            step_size=hmc_step_size,
            n_steps=hmc_n_steps,
        )
    return kernel


xs = jnp.linspace(0.0, 1.0, 8)
trace = npoint_curve_with_outliers.simulate(xs)
_, (_, ys) = trace.get_retval()
constraints = {"ys": {"y": {"obs": ys}}}
initial_trace, _ = npoint_curve_with_outliers.generate(constraints, xs, 0.1)
kernel = mixed_gibbs_hmc_kernel(xs, ys, hmc_step_size=0.02, hmc_n_steps=5)
runner = chain(kernel)
samples = seed(runner)(jax.random.key(0), initial_trace, n_steps=Const(10))
print(samples.traces.get_choices()["curve"]["a"].shape)
```

These moves yield a chain that captures both inlier/outlier classifications and posterior uncertainty over the polynomial coefficients.
For the full treatment—including command-line tooling, figure generation, and alternative inference routines—see examples/curvefit.
Install pixi (the package manager used to build and run the case studies).
```
cd genjax
pixi install
```

This creates isolated conda environments for each case study.
Note: all figures are reproducible from this repository, including the multi-system benchmarking panel (Fig 16b). That benchmark now lives under examples/perfbench; plan for ~5–10 minutes per run whether you target CPU or CUDA.
To generate several of the case study paper figures, use the following commands:
```
# CPU execution
pixi run paper-figures

# GPU execution (requires CUDA 12)
pixi run paper-figures-gpu
```

All figures are saved to `genjax/figs/`.
Here's a list of behaviors to expect when running the case studies on CPU (see also "Case-study device expectations" for per-study notes):
- CPU takes longer than GPU: a full run on an Apple M4 (MacBook Air) takes around 4 minutes.
- CPU won't exhibit the same vectorized scaling properties as GPU (in many cases, linear versus near-constant scaling).
- When running the artifact on CPU only, some of the timing figures may be missing comparisons between CPU and GPU.
Keep these behaviors in mind when interpreting figures generated via CPU execution.
Different case studies stress different JAX vectorization patterns, so device characteristics matter:
- Fair Coin – entirely CPU-friendly; GPU only affects wall-clock time, not the outputs.
- Curvefit – scaling/timing plots assume a GPU so that the importance sampler can sweep up to 1M particles; on CPU you can pass smaller `--scaling-*` flags (see below) and expect longer runtimes plus different timing curves.
- Game of Life – Gibbs timing bars compare CPU vs. GPU throughput; animations work everywhere, but the scaling panel only matches the paper if both device types are available.
- Localization – Figure 19 relies on access to a CUDA environment (`pixi run -e localization-cuda …`). The SMC+HMC rejuvenation loop and vectorized LIDAR beams batch over large arrays; on CPU we drop the computational load to much smaller grids/particle counts, so the four-panel comparison generated via the CPU command will differ. If you have access to a GPU, use the GPU command to match the paper.
- Performance benchmark – the `examples/perfbench` case study (Figure 16b) spins up separate Pixi environments for NumPyro, TensorFlow Probability, Pyro, hand-coded PyTorch, and Gen.jl. CUDA is required for the default Pyro/TFP runs (pass `--device cpu` when invoking the scripts to fall back to CPU). Budget roughly 5–10 minutes for the full sweep on either CPU or GPU; trim the particle/chain grids or repeats if you only need a smoke test.
We expect that any environment which supports JAX should allow you to run our artifact (using pixi), but for precision, here is the list of devices on which we tested the artifact (using `pixi run paper-figures` for CPU and `pixi run paper-figures-gpu` for GPU):
Apple M4 (MacBook Air)
- Model: MacBook Air (Mac16,12)
- Chip: Apple M4
- CPU Cores: 10 cores total (4 performance + 6 efficiency)
- Memory: 16 GB
- OS: macOS 15.6 (Sequoia)
- Build: 24G84
- Kernel: Darwin 24.6.0
Linux machine with Nvidia RTX 4090
- Pop!_OS 22.04 LTS
- Kernel: Linux 6.16.3-76061603-generic
- AMD Ryzen 7 7800X3D 8-Core Processor
- 16 threads (8 cores with SMT)
- Max frequency: 5.05 GHz
- 96 MiB L3 cache
- NVIDIA GeForce RTX 4090 (24GB VRAM)
In this section, we provide more details on the case studies. For each case study, we reference the figures in the paper that it supports.
What it does: Compares GenJAX, hand-coded JAX, and NumPyro on a simple inference problem with a known exact posterior.
Figures in the paper: Figure 16(a).
Command:
```
pixi run -e faircoin python -m examples.faircoin.main \
    --combined --num-obs 50 --num-samples 2000 --repeats 10
```

Outputs: `figs/faircoin_combined_posterior_and_timing_obs50_samples2000.pdf`
What it does: Polynomial regression with robust outlier detection, demonstrating:
- GPU scaling of importance sampling with varying particle counts
- Gibbs sampling with HMC for an outlier mixture model
Figures in the paper: Figure 4, Figure 5, Figure 6.
Command (CPU, full sweep):
```
pixi run paper-curvefit-gen
```

Outputs: 5 figures in `figs/`:
- `curvefit_prior_multipoint_traces_density.pdf`
- `curvefit_single_multipoint_trace_density.pdf`
- `curvefit_scaling_performance.pdf`
- `curvefit_posterior_scaling_combined.pdf`
- `curvefit_outlier_detection_comparison.pdf`
Low-resource variant (pass custom flags):

```
pixi run paper-curvefit-custom -- --scaling-max-samples 20000 --scaling-trials 2
```

You can still run the raw Python entry point directly:

```
pixi run -e curvefit python -m examples.curvefit.main paper \
    --scaling-max-samples 20000 --scaling-trials 2
```

For CUDA, use `pixi run paper-curvefit-gpu-gen` (full sweep) or `pixi run paper-curvefit-gpu-custom -- --scaling-max-samples 20000 --scaling-trials 2`.
These variants trim the GPU scaling benchmark to particle counts ≤20k (or supply `--scaling-particle-counts` for a custom grid).
Julia requirement: The Gen.jl baselines need Julia ≥1.10. We recommend installing via juliaup so that `julia` is on your PATH (e.g., `curl -fsSL https://install.julialang.org | sh`).
What it does: Repackages the full timing-benchmarks project (commit d4433b0) that produces Figure 16(b): GenJAX vs. NumPyro, Pyro, TensorFlow Probability, hand-coded PyTorch, and Gen.jl on both importance sampling and HMC for the polynomial regression task.
Where it lives: `examples/perfbench` remains its own Pixi project, but the repo-level task `pixi run paper-perfbench …` shells into `python examples/perfbench/main.py pipeline`, so you can supply all flags from the root.
```
# CPU run (default; data -> examples/perfbench/data_cpu, figs -> figs_cpu)
pixi run paper-perfbench

# CUDA run (JAX frameworks get JAX_PLATFORMS=cuda, Pyro/Torch receive --device cuda)
pixi run paper-perfbench --mode cuda

# Stage selection
pixi run paper-perfbench --inference is   # Importance sampling only
pixi run paper-perfbench --inference hmc  # HMC only

# Framework selection (applies to both stages unless overridden)
pixi run paper-perfbench --frameworks genjax numpyro handcoded_jax

# Stage-specific overrides
pixi run paper-perfbench --is-frameworks genjax numpyro --hmc-frameworks genjax numpyro pyro
```

Key pipeline flags:
- `--particles` controls the IS particle grid (default 1k/5k/10k).
- `--is-repeats` / `--is-inner-repeats` default to 50×50; GenJAX, NumPyro, and hand-coded JAX automatically bump to 100×100 unless you pass explicit values.
- `--hmc-chain-lengths`, `--hmc-repeats`, `--hmc-warmup`, `--hmc-step-size`, `--hmc-n-leapfrog` feed the shared HMC runner; Gen.jl reuses the same knobs via `genjl-hmc`.
- `--skip-generate`, `--skip-is`, `--skip-hmc`, `--skip-plots`, `--skip-export` let you resume partially completed runs without redoing work. Plotting/export only happens when fresh data exists.
Example (HMC-only CUDA sweep with trimmed repeats and no plotting/export):
```
pixi run paper-perfbench \
    --mode cuda \
    --frameworks genjax numpyro \
    --inference hmc \
    --hmc-chain-lengths 1000 5000 \
    --hmc-repeats 25 \
    --skip-plots --skip-export
```

Implementation details:
- The pipeline automatically selects the right Pixi environment per framework (`cuda` for TensorFlow Probability + JAX HMC, `pyro` for Pyro, `torch` for the hand-coded PyTorch kernels, default for everything else) and injects `JAX_PLATFORMS` so CUDA/CPU runs are explicit.
- HMC timings go through `benchmarks/run_hmc_benchmarks.py`. Pyro, hand-coded PyTorch, and Gen.jl are intentionally capped at 5 outer × 5 inner repeats inside that runner so the full sweep stays within a ~5–10 minute window; override them by running the helper script directly if you need denser statistics.
- Results land in `examples/perfbench/data/curvefit/<framework>/` and `examples/perfbench/figs/…` for CUDA mode (CPU mode mirrors them under `data_cpu/` and `figs_cpu/`). Gen.jl artifacts show up once its Julia project auto-instantiates (`julia --project=examples/perfbench/benchmarks/julia -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()'`).
Once both stages finish, the pipeline merges all available JSON into `figs_{cpu,}/benchmark_timings_{is,hmc}_all_frameworks.{pdf,png}` plus `benchmark_summary_*.csv` and `benchmark_table.tex`, then copies PDFs into the repo-level `figs/` directory if `--skip-export` is not set. See `examples/perfbench/AGENTS.md` for the full CLI reference.
Julia requirement: The Gen.jl baselines need a Julia toolchain (≥1.10). We recommend installing via juliaup so that `julia` is on your `PATH` (e.g., `curl -fsSL https://install.julialang.org | sh`). The first time the pipeline touches Gen.jl, it will automatically run `julia --project=examples/perfbench/benchmarks/julia -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()'`.
What it does: Infers past Game of Life states from observed future states using Gibbs sampling on a 512×512 grid with 250 Gibbs steps.
Figures in the paper: Figure 18.
Command:
```
pixi run -e gol gol-paper
```

Outputs: 2 figures in `figs/`:
- `gol_integrated_showcase_wizards_512.pdf` (3-panel inference showcase)
- `gol_gibbs_timing_bar_plot.pdf` (performance across grid sizes)
Note: Timing bar plot runs benchmarks at 64×64, 128×128, 256×256, and 512×512 grid sizes.
What it does: Particle filter localization comparing a bootstrap filter, SMC+HMC, and approximate locally optimal proposals (computed via grid enumeration), using 200 particles and a generative model with a simulated 8-ray LIDAR measurement.
Figures in the paper: Figure 19.
Command (GPU, matches paper):
```
pixi run -e localization-cuda python -m examples.localization.main paper \
    --include-basic-demo --include-smc-comparison \
    --n-particles 200 --n-steps 8 --timing-repeats 3 --n-rays 8 --output-dir figs
```

Running without CUDA (`pixi run -e localization …`) executes the same probabilistic program, but the SMC benchmark scales back the vectorized LIDAR beams and rejuvenation sweeps to keep the CPU runtime manageable, so timing/ESS panels will no longer match the GPU figure shown in the paper (see "Case-study device expectations" above for details).

Historical CPU command (slower, produces reduced-resolution figures):

```
pixi run -e localization python -m examples.localization.main paper \
    --include-basic-demo --include-smc-comparison \
    --n-particles 200 --n-steps 8 --timing-repeats 3 --n-rays 8 --output-dir figs
```

Outputs: 2 figures in `figs/`:
- `localization_r8_p200_basic_localization_problem_1x4_explanation.pdf`
- `localization_r8_p200_basic_comprehensive_4panel_smc_methods_analysis.pdf`
Note: Also generates experimental data in `examples/localization/data/` (regenerated each run).
All figures are saved to `figs/`:
- `faircoin_combined_posterior_and_timing_obs50_samples2000.pdf` – Framework comparison (timing + posterior accuracy)
- `curvefit_prior_multipoint_traces_density.pdf` – Prior samples from generative model
- `curvefit_single_multipoint_trace_density.pdf` – Single trace with log density
- `curvefit_scaling_performance.pdf` – Inference scaling with particle count
- `curvefit_posterior_scaling_combined.pdf` – Posterior quality at different scales
- `curvefit_outlier_detection_comparison.pdf` – Robust inference with mixture model
- `curvefit_vectorization_illustration.pdf` – Static diagram (already in repo)
- `gol_integrated_showcase_wizards_512.pdf` – Inverse dynamics inference (3 panels)
- `gol_gibbs_timing_bar_plot.pdf` – Performance scaling across grid sizes
- `localization_r8_p200_basic_localization_problem_1x4_explanation.pdf` – Problem setup
- `localization_r8_p200_basic_comprehensive_4panel_smc_methods_analysis.pdf` – SMC method comparison
- `perfbench_benchmark_timings_is_all_frameworks.pdf` / `perfbench_benchmark_timings_hmc_all_frameworks.pdf` – CUDA sweep (IS + HMC across all frameworks)
- `perfbench_cpu_benchmark_timings_is_all_frameworks.pdf` / `perfbench_cpu_benchmark_timings_hmc_all_frameworks.pdf` – CPU-only rerun for environments without GPUs
