Pseudo-Observation Parametrisations #121

Open
wants to merge 10 commits into base: master
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.3.4"
version = "0.3.5"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
28 changes: 3 additions & 25 deletions docs/src/userguide.md
@@ -46,31 +46,9 @@ The approximate posterior constructed above will be a very poor approximation, s
```julia
elbo(SparseVariationalApproximation(fz, q), fx, y)
```
A detailed example of how to carry out such optimisation is given in [Regression: Sparse Variational Gaussian Process for Stochastic Optimisation with Flux.jl](@ref). For an example of non-conjugate inference, see [Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS](@ref).

# Available Parametrizations

Two parametrizations of `q(u)` are presently available: [`Centered`](@ref) and [`NonCentered`](@ref).
The `Centered` parametrization expresses `q(u)` directly in terms of its mean and covariance.
The `NonCentered` parametrization instead parametrizes the mean and covariance of
`ε := cholesky(cov(u)).U' \ (u - mean(u))`.
These parametrizations are also known respectively as "Unwhitened" and "Whitened".

The choice of parametrization can have a substantial impact on the time it takes for ELBO
optimization to converge, and which parametrization is better in a particular situation is
not generally obvious.
That being said, the `NonCentered` parametrization often converges in fewer iterations, so it is the default --
it is what is used in all of the examples above.

If you require a particular parametrization, simply use the 3-argument version of the
approximation constructor:
```julia
SparseVariationalApproximation(Centered(), fz, q)
SparseVariationalApproximation(NonCentered(), fz, q)
```

For a general discussion around these two parametrizations, see e.g. [^Gorinova].
For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek].

[^Gorinova]: Gorinova, Maria and Moore, Dave and Hoffman, Matthew [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a)
[^Paciorek]: [Paciorek, Christopher Joseph. Nonstationary Gaussian processes for regression and spatial modelling. Diss. Carnegie Mellon University, 2003.](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf)
There are various ways to parametrise the approximate posterior.
See [The Various Pseudo-Point Approximation Parametrisations](@ref) for more info and
worked examples.
17 changes: 17 additions & 0 deletions examples/d-sparse-parametrisations/Project.toml
@@ -0,0 +1,17 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
171 changes: 171 additions & 0 deletions examples/d-sparse-parametrisations/script.jl
@@ -0,0 +1,171 @@
# # The Various Pseudo-Point Approximation Parametrisations
#
# ### Note to the reader
# At the time of writing (March 2021) the best way to parametrise the approximate posterior
# remains a surprisingly active area of research.
# If you are reading this and feel that it has become outdated, or was incorrect in the
# first instance, it would be greatly appreciated if you could open an issue to discuss.
#
#
# ## Introduction
#
# This example examines the various ways in which this package supports parametrising the
# approximate posterior when utilising sparse approximations.
#
# All sparse (a.k.a. pseudo-point) approximations in this package utilise an approximate
# posterior over a GP ``f`` of the form
# ```math
# q(f) = q(\mathbf{u}) \, p(f_{\neq \mathbf{u}} | \mathbf{u})
# ```
# where samples from ``f`` are functions mapping ``\mathcal{X} \to \mathbb{R}``,
# ``\mathbf{u} := f(\mathbf{z})``, ``\mathbf{z} \in \mathcal{X}^M`` are the pseudo-inputs,
# and ``f_{\neq \mathbf{u}}`` denotes ``f`` at all indices other than those in
# ``\mathbf{z}``.[^Titsias]
# ``q(\mathbf{u})`` is generally restricted to be a multivariate Gaussian, to which end
# ApproximateGPs presently offers four parametrisations (sketched in code after this list):
# 1. Centered ("Unwhitened"): ``q(\mathbf{u}) = \mathcal{N}(\mathbf{m}, \mathbf{C})``, ``\quad \mathbf{m} \in \mathbb{R}^M`` and positive-definite ``\mathbf{C} \in \mathbb{R}^{M \times M}``,
# 1. Non-Centered ("Whitened"): ``q(\mathbf{u}) = \mathcal{N}(\mathbf{L} \mathbf{m}, \mathbf{L} \mathbf{C} \mathbf{L}^\top)``, ``\quad \mathbf{L} \mathbf{L}^\top = \text{cov}(\mathbf{u})``,
# 1. Pseudo-Observation: ``q(\mathbf{u}) \propto p(\mathbf{u}) \, \mathcal{N}(\hat{\mathbf{y}}; \mathbf{u}, \hat{\mathbf{S}})``, ``\quad \hat{\mathbf{y}} \in \mathbb{R}^M`` and positive-definite ``\hat{\mathbf{S}} \in \mathbb{R}^{M \times M}``,
# 1. Decoupled Pseudo-Observation: ``q(\mathbf{u}) \propto p(\mathbf{u}) \, \mathcal{N}(\hat{\mathbf{y}}; f(\mathbf{v}), \hat{\mathbf{S}})``, ``\quad \hat{\mathbf{y}} \in \mathbb{R}^R``, ``\hat{\mathbf{S}} \in \mathbb{R}^{R \times R}`` is positive-definite and diagonal, and ``\mathbf{v} \in \mathcal{X}^R``.
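#
# As a rough sketch of how these parametrisations map onto the API (the argument orders
# follow the constructors used later in this example; `q`, `Ŝ`, `ŷ`, and `v` stand in for
# objects constructed below):
# ```julia
# SparseVariationalApproximation(Centered(), fz, q)       # 1. centered
# SparseVariationalApproximation(NonCentered(), fz, q)    # 2. non-centered
# PseudoObsSparseVariationalApproximation(f, z, Ŝ, ŷ)     # 3. pseudo-observation
# PseudoObsSparseVariationalApproximation(f, z, Ŝ, v, ŷ)  # 4. decoupled pseudo-observation
# ```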
#
# The choice of parametrization can have a substantial impact on the time it takes for ELBO
# optimization to converge, and which parametrization is better in a particular situation is
# not generally obvious.
# That being said, the `NonCentered` parametrization often converges in fewer iterations
# than the `Centered`, and is widely used, so it is the default.
#
# For a general discussion of the centered vs non-centered distinction, see e.g. [^Gorinova].
# For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek].

# ## Setup

using AbstractGPs
using ApproximateGPs
using CairoMakie
using Distributions
using Images
using KernelFunctions
using LinearAlgebra
using Optim
using Random
using Zygote

# A simple GP with inputs on the reals.
f = GP(SEKernel());
N = 100;
x = range(-3.0, 3.0; length=N);

# Generate some observations.
Σ = Diagonal(fill(0.1, N));
y = rand(Xoshiro(123456), f(x, Σ));

# Use a handful of pseudo-points.
M = 10;
z = range(-3.5, 3.5; length=M);

# Other misc. constants that we'll need later:
x_pred = range(-5.0, 5.0; length=300);
jitter = 1e-9;

# ## The Relationship Between Parametrisations
#
# Much of the time, one can convert between the different parametrisations to obtain
# equivalent ``q(\mathbf{u})``, for a given set of hyperparameters.
# If it's unclear from the above how these parametrisations relate to one another, the
# following should help to crystallise the relationship.
#
# ### Centered vs Non-Centered
#
# Both the `Centered` and `NonCentered` parametrisations are specified by a mean vector `m`
# and covariance matrix `C`, but in slightly different ways.
# The `Centered` parametrisation interprets `m` and `C` as the mean and covariance of
# ``q(\mathbf{u})`` directly, while the `NonCentered` parametrisation interprets them as the
# mean and covariance of the approximate posterior over
# `ε := cholesky(cov(u)).U' \ (u - mean(u))`.
#
# To see this, consider the following non-centered approximate posterior:
fz = f(z, jitter);
qu_non_centered = MvNormal(randn(M), Matrix{Float64}(I, M, M));
non_centered_approx = SparseVariationalApproximation(NonCentered(), fz, qu_non_centered);

# The equivalent centered parametrisation can be found by multiplying the parameters of
# `qu_non_centered` by the Cholesky factor of the prior covariance:
L = cholesky(Symmetric(cov(fz))).L;
qu_centered = MvNormal(L * mean(qu_non_centered), L * cov(qu_non_centered) * L');
centered_approx = SparseVariationalApproximation(Centered(), fz, qu_centered);

# We can gain some confidence that they're actually the same by querying the approximate
# posterior statistics at some new locations:
q_non_centered = posterior(non_centered_approx)
q_centered = posterior(centered_approx)
@assert mean(q_non_centered(x_pred)) ≈ mean(q_centered(x_pred))
@assert cov(q_non_centered(x_pred)) ≈ cov(q_centered(x_pred))
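
# The reverse transformation recovers the non-centered parameters from the centered ones by
# dividing out the Cholesky factor (a quick sanity check using the objects defined above):
m_ε = L \ mean(qu_centered);      # recovers mean(qu_non_centered)
C_ε = L \ cov(qu_centered) / L';  # recovers cov(qu_non_centered)
@assert m_ε ≈ mean(qu_non_centered)
@assert C_ε ≈ cov(qu_non_centered)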

# ### Pseudo-Observation vs Centered
#
# The relationship between these two parametrisations is only slightly more complicated.
# Consider the following pseudo-observation parametrisation of the approximate posterior:
ŷ = randn(M);
Ŝ = Matrix{Float64}(I, M, M);
pseudo_obs_approx = PseudoObsSparseVariationalApproximation(f, z, Ŝ, ŷ);
q_pseudo_obs = posterior(pseudo_obs_approx);

# The corresponding centered approximation is given via the usual Gaussian conditioning
# formulae:
C = cov(fz);
C_centered = C - C * (cholesky(Symmetric(C + Ŝ)) \ C);
m_centered = mean(fz) + C / cholesky(Symmetric(C + Ŝ)) * (ŷ - mean(fz));
qu_centered = MvNormal(m_centered, Symmetric(C_centered));
centered_approx = SparseVariationalApproximation(Centered(), fz, qu_centered);
q_centered = posterior(centered_approx);

# Again, we can gain some confidence that they're the same by comparing the posterior
# marginal statistics.
@assert mean(q_pseudo_obs(x_pred)) ≈ mean(q_centered(x_pred))
@assert cov(q_pseudo_obs(x_pred)) ≈ cov(q_centered(x_pred))

# While it's always possible to find an approximation using the centered parametrisation
# which is equivalent to a given pseudo-observation parametrisation, the converse is not
# true.
# That is, for a given `C = cov(fz)` and particular choice of covariance matrix `Ĉ` in a
# centered parametrisation, it may not be the case that there exists a positive-definite
# pseudo-observation covariance matrix `Ŝ` such that ``\hat{C} = C - C (C + \hat{S})^{-1} C``.
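#
# One way to see this (a sketch of the standard argument): by the Woodbury identity,
# ```math
# \hat{C} = C - C (C + \hat{S})^{-1} C
# \quad \iff \quad
# \hat{C}^{-1} = C^{-1} + \hat{S}^{-1},
# ```
# so the required ``\hat{S} = (\hat{C}^{-1} - C^{-1})^{-1}`` exists and is positive-definite
# precisely when ``\hat{C}^{-1} - C^{-1}`` is positive-definite, i.e. when ``\hat{C} \prec C``.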
#
# However, this is not necessarily a problem: if the likelihood used in the model is
# log-concave then the optimal choice for `Ĉ` can always be represented using this
# pseudo-observation parametrisation.
# Even when this is not the case, the optimal choice for `q(u)` may well still live inside
# the family of distributions which can be expressed within the pseudo-observation family.

Member
do we have a citation for this? 🤔

Member Author
I was wondering about that. I've definitely seen the result floating around (and it's easy enough to prove) -- will have a hunt.

Member
if it's that easy to prove, just do it here in the docs 😂

#
# ### Decoupled Pseudo-Observation vs Non-Centered
#
# The relationship here is the most delicate, due to the restriction that
# ``\hat{\mathbf{S}}`` must be diagonal.
# This approximation achieves the optimal approximate posterior when the choice of
# pseudo-observational data (``\hat{\mathbf{y}}``, ``\hat{\mathbf{S}}``, and ``\mathbf{v}``)
# equals the original observational data.
# When the original observational data involves a non-Gaussian likelihood, this
# approximation family can still obtain the optimal approximate posterior provided that
# ``\mathbf{v}`` lines up with the inputs associated with the original data, ``\mathbf{x}``.
#
# To see this, consider the pseudo-observation approximation which makes use of the
# original observational data (generated at the top of this example):
decoupled_approx = PseudoObsSparseVariationalApproximation(f, z, Σ, x, y);
decoupled_posterior = posterior(decoupled_approx);

# We can get the optimal pseudo-point approximation using standard functionality:
optimal_approx_post = posterior(VFE(f(z, jitter)), f(x, Σ), y);

# The marginal statistics agree:
@assert mean(optimal_approx_post(x_pred)) ≈ mean(decoupled_posterior(x_pred))
@assert cov(optimal_approx_post(x_pred)) ≈ cov(decoupled_posterior(x_pred))

# This property is the reason to expect this parametrisation to do something sensible.
# Obviously, when ``\mathbf{v} \neq \mathbf{x}`` the optimal approximate posterior cannot
# generally be recovered; the hope, however, is that there exists a small pseudo-dataset
# which gets close to the optimum.
Comment on lines +164 to +167

Member
what's the advantage of this parametrisation?


# [^Titsias]: Titsias, M. K. [Variational learning of inducing variables in sparse Gaussian processes](https://proceedings.mlr.press/v5/titsias09a.html)
# [^Gorinova]: Gorinova, Maria and Moore, Dave and Hoffman, Matthew [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a)
# [^Paciorek]: [Paciorek, Christopher Joseph. Nonstationary Gaussian processes for regression and spatial modelling. Diss. Carnegie Mellon University, 2003.](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf)
3 changes: 2 additions & 1 deletion src/ApproximateGPs.jl
@@ -15,7 +15,8 @@ include("SparseVariationalApproximationModule.jl")
SparseVariationalApproximation, Centered, NonCentered
@reexport using .SparseVariationalApproximationModule:
DefaultQuadrature, Analytic, GaussHermite, MonteCarlo

@reexport using .SparseVariationalApproximationModule:
PseudoObsSparseVariationalApproximation, ObsCovLikelihood, DecoupledObsCovLikelihood
include("LaplaceApproximationModule.jl")
@reexport using .LaplaceApproximationModule: LaplaceApproximation
@reexport using .LaplaceApproximationModule: