Pseudo-Observation Parametrisations #121
Open
willtebbutt wants to merge 10 commits into master from wct/pseudo-observation-parametrisations
Commits (10):
- 2e529e0: Pseudo-observation parametrisations (willtebbutt)
- 2d91393: Pseudo-obs example (willtebbutt)
- a7a7143: Merge branch 'master' into wct/pseudo-observation-parametrisations (willtebbutt)
- 50d0a3a: Apply suggestions from code review (willtebbutt)
- cf64d24: Apply suggestions from code review (willtebbutt)
- 348d0d0: Fix problems from formatting (willtebbutt)
- 6961c23: Fix formatting (willtebbutt)
- 41e93ba: Bump patch (willtebbutt)
- 87e2c39: Fix things the formatter broke (willtebbutt)
- d47c809: Fix remaining error from formatter (willtebbutt)
@@ -0,0 +1,17 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -0,0 +1,171 @@
# # The Various Pseudo-Point Approximation Parametrisations
#
# ### Note to the reader
# At the time of writing (March 2021) the best way to parametrise the approximate posterior
# remains a surprisingly active area of research.
# If you are reading this and feel that it has become outdated, or was incorrect in the
# first instance, it would be greatly appreciated if you could open an issue to discuss.
#
#
# ## Introduction
#
# This example examines the various ways in which this package supports parametrising the
# approximate posterior when utilising sparse approximations.
#
# All sparse (a.k.a. pseudo-point) approximations in this package utilise an approximate
# posterior over a GP ``f`` of the form
# ```math
# q(f) = q(\mathbf{u}) \, p(f_{\neq \mathbf{u}} | \mathbf{u})
# ```
# where samples from ``f`` are functions mapping ``\mathcal{X} \to \mathbb{R}``,
# ``\mathbf{u} := f(\mathbf{z})``, ``\mathbf{z} \in \mathcal{X}^M`` are the pseudo-inputs,
# and ``f_{\neq \mathbf{u}}`` denotes ``f`` at all indices other than those in
# ``\mathbf{z}``.[^Titsias]
# ``q(\mathbf{u})`` is generally restricted to be a multivariate Gaussian, to which end
# ApproximateGPs presently offers four parametrisations:
# 1. Centered ("Unwhitened"): ``q(\mathbf{u}) = \mathcal{N}(\mathbf{m}, \mathbf{C})``, ``\quad \mathbf{m} \in \mathbb{R}^M`` and positive-definite ``\mathbf{C} \in \mathbb{R}^{M \times M}``,
# 1. Non-Centered ("Whitened"): ``q(\mathbf{u}) = \mathcal{N}(\mathbf{L} \mathbf{m}, \mathbf{L} \mathbf{C} \mathbf{L}^\top)``, ``\quad \mathbf{L} \mathbf{L}^\top = \text{cov}(\mathbf{u})``,
# 1. Pseudo-Observation: ``q(\mathbf{u}) \propto p(\mathbf{u}) \, \mathcal{N}(\hat{\mathbf{y}}; \mathbf{u}, \hat{\mathbf{S}})``, ``\quad \hat{\mathbf{y}} \in \mathbb{R}^M`` and positive-definite ``\hat{\mathbf{S}} \in \mathbb{R}^{M \times M}``,
# 1. Decoupled Pseudo-Observation: ``q(\mathbf{u}) \propto p(\mathbf{u}) \, \mathcal{N}(\hat{\mathbf{y}}; f(\mathbf{v}), \hat{\mathbf{S}})``, ``\quad \hat{\mathbf{y}} \in \mathbb{R}^R``, ``\hat{\mathbf{S}} \in \mathbb{R}^{R \times R}`` is positive-definite and diagonal, and ``\mathbf{v} \in \mathcal{X}^R``.
#
# The choice of parametrisation can have a substantial impact on the time it takes for ELBO
# optimisation to converge, and which parametrisation is better in a particular situation
# is not generally obvious.
# That being said, the `NonCentered` parametrisation often converges in fewer iterations
# than the `Centered`, and is widely used, so it is the default.
#
# For a general discussion of centered vs non-centered parametrisations, see e.g.
# [^Gorinova].
# For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek].

# ## Setup

using AbstractGPs
using ApproximateGPs
using CairoMakie
using Distributions
using Images
using KernelFunctions
using LinearAlgebra
using Optim
using Random
using Zygote

# A simple GP with inputs on the reals.
f = GP(SEKernel());
N = 100;
x = range(-3.0, 3.0; length=N);

# Generate some observations.
Σ = Diagonal(fill(0.1, N));
y = rand(Xoshiro(123456), f(x, Σ));

# Use a handful of pseudo-points.
M = 10;
z = range(-3.5, 3.5; length=M);

# Other misc. constants that we'll need later:
x_pred = range(-5.0, 5.0; length=300);
jitter = 1e-9;

# ## The Relationship Between Parametrisations
#
# Much of the time, one can convert between the different parametrisations to obtain an
# equivalent ``q(\mathbf{u})`` for a given set of hyperparameters.
# If it's unclear from the above how these parametrisations relate to one another, the
# following should help to crystallise the relationship.
#
# ### Centered vs Non-Centered
#
# Both the `Centered` and `NonCentered` parametrisations are specified by a mean vector `m`
# and covariance matrix `C`, but in slightly different ways.
# The `Centered` parametrisation interprets `m` and `C` as the mean and covariance of
# ``q(\mathbf{u})`` directly, while the `NonCentered` parametrisation interprets them as the
# mean and covariance of the approximate posterior over
# `ε := cholesky(cov(u)).U' \ (u - mean(u))`.
#
# To see this, consider the following non-centered approximate posterior:
fz = f(z, jitter);
qu_non_centered = MvNormal(randn(M), Matrix{Float64}(I, M, M));
non_centered_approx = SparseVariationalApproximation(NonCentered(), fz, qu_non_centered);

# The equivalent centered parametrisation can be found by multiplying the parameters of
# `qu_non_centered` by the Cholesky factor of the prior covariance:
L = cholesky(Symmetric(cov(fz))).L;
qu_centered = MvNormal(L * mean(qu_non_centered), L * cov(qu_non_centered) * L');
centered_approx = SparseVariationalApproximation(Centered(), fz, qu_centered);

# We can gain some confidence that they're actually the same by querying the approximate
# posterior statistics at some new locations:
q_non_centered = posterior(non_centered_approx)
q_centered = posterior(centered_approx)
@assert mean(q_non_centered(x_pred)) ≈ mean(q_centered(x_pred))
@assert cov(q_non_centered(x_pred)) ≈ cov(q_centered(x_pred))

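# As a further sanity check (a small sketch using only the quantities defined above),
# inverting the transformation recovers the non-centered parameters from the centered ones:
@assert L \ mean(qu_centered) ≈ mean(qu_non_centered)
@assert L \ cov(qu_centered) / L' ≈ cov(qu_non_centered)
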
# ### Pseudo-Observation vs Centered
#
# The relationship between these two parametrisations is only slightly more complicated.
# Consider the following pseudo-observation parametrisation of the approximate posterior:
ŷ = randn(M);
Ŝ = Matrix{Float64}(I, M, M);
pseudo_obs_approx = PseudoObsSparseVariationalApproximation(f, z, Ŝ, ŷ);
q_pseudo_obs = posterior(pseudo_obs_approx);

# The corresponding centered approximation is given via the usual Gaussian conditioning
# formulae:
C = cov(fz);
C_centered = C - C * (cholesky(Symmetric(C + Ŝ)) \ C);
m_centered = mean(fz) + C / cholesky(Symmetric(C + Ŝ)) * (ŷ - mean(fz));
qu_centered = MvNormal(m_centered, Symmetric(C_centered));
centered_approx = SparseVariationalApproximation(Centered(), fz, qu_centered);
q_centered = posterior(centered_approx);

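# For reference, these are the standard Gaussian conditioning formulae for a prior
# ``\mathcal{N}(\text{mean}(\mathbf{u}), \mathbf{C})`` combined with the pseudo-likelihood
# ``\mathcal{N}(\hat{\mathbf{y}}; \mathbf{u}, \hat{\mathbf{S}})``:
# ```math
# \mathbf{m} = \text{mean}(\mathbf{u}) + \mathbf{C} (\mathbf{C} + \hat{\mathbf{S}})^{-1} (\hat{\mathbf{y}} - \text{mean}(\mathbf{u})), \qquad
# \hat{\mathbf{C}} = \mathbf{C} - \mathbf{C} (\mathbf{C} + \hat{\mathbf{S}})^{-1} \mathbf{C}.
# ```
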
# Again, we can gain some confidence that they're the same by comparing the posterior
# marginal statistics.
@assert mean(q_pseudo_obs(x_pred)) ≈ mean(q_centered(x_pred))
@assert cov(q_pseudo_obs(x_pred)) ≈ cov(q_centered(x_pred))

# While it's always possible to find an approximation using the centered parametrisation
# which is equivalent to a given pseudo-observation parametrisation, the converse is not
# true.
# That is, for a given `C = cov(fz)` and particular choice of covariance matrix `Ĉ` in a
# centered parametrisation, there may be no positive-definite pseudo-observation
# covariance matrix `Ŝ` such that ``\hat{C} = C - C (C + \hat{S})^{-1} C``.
#
# However, this is not necessarily a problem: if the likelihood used in the model is
# log-concave then the optimal choice for `Ĉ` can always be represented using this
# pseudo-observation parametrisation.
# Even when this is not the case, the optimal choice for `q(u)` is not guaranteed to lie
# outside of the family of distributions which can be expressed within the
# pseudo-observation parametrisation.
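# To make the non-representable case concrete, here is a small numerical sketch (the
# inflated covariance `Ĉ` below is a deliberately inadmissible choice, used purely for
# illustration): asking for a centered covariance larger than the prior covariance forces
# the implied pseudo-observation covariance to be non-positive-definite.
Ĉ = 2 * C;
S_implied = C * ((C - Ĉ) \ C) - C;  # rearrange Ĉ = C - C (C + Ŝ)⁻¹ C for Ŝ
@assert !isposdef(Symmetric(S_implied))
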
# ### Decoupled Pseudo-Observation vs Non-Centered
#
# The relationship here is the most delicate, due to the restriction that
# ``\hat{\mathbf{S}}`` must be diagonal.
# This approximation achieves the optimal approximate posterior when the
# pseudo-observational data (``\hat{\mathbf{y}}``, ``\hat{\mathbf{S}}``, and ``\mathbf{v}``)
# equal the original observational data.
# When the original observational data involve a non-Gaussian likelihood, this
# approximation family can still obtain the optimal approximate posterior, provided that
# ``\mathbf{v}`` lines up with the inputs associated with the original data, ``\mathbf{x}``.
#
# To see this, consider the pseudo-observation approximation which makes use of the
# original observational data (generated at the top of this example):
decoupled_approx = PseudoObsSparseVariationalApproximation(f, z, Σ, x, y);
decoupled_posterior = posterior(decoupled_approx);

# We can get the optimal pseudo-point approximation using standard functionality:
optimal_approx_post = posterior(VFE(f(z, jitter)), f(x, Σ), y);

# The marginal statistics agree:
@assert mean(optimal_approx_post(x_pred)) ≈ mean(decoupled_posterior(x_pred))
@assert cov(optimal_approx_post(x_pred)) ≈ cov(decoupled_posterior(x_pred))

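# For contrast, a quick sketch (the subsampled `v`, `y_sub`, and `Σ_sub` below are
# hypothetical, chosen purely for illustration): if ``\mathbf{v}`` contains only a subset
# of ``\mathbf{x}``, the optimal approximate posterior is generally no longer recovered
# exactly, although it may still be approximated well.
v = x[1:2:end];
y_sub = y[1:2:end];
Σ_sub = Diagonal(fill(0.1, length(v)));
subset_approx = PseudoObsSparseVariationalApproximation(f, z, Σ_sub, v, y_sub);
subset_posterior = posterior(subset_approx);
maximum(abs, mean(subset_posterior(x_pred)) - mean(optimal_approx_post(x_pred)))
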
# This property is the main reason to expect this parametrisation to behave sensibly.
# Obviously, when ``\mathbf{v} \neq \mathbf{x}`` the optimal approximate posterior cannot
# be recovered; the hope, however, is that there exists a small pseudo-dataset which gets
# close to the optimum.
Review comment (on lines +164 to +167): what's the advantage of this parametrisation?

# [^Titsias]: Titsias, M. K. [Variational Learning of Inducing Variables in Sparse Gaussian Processes](https://proceedings.mlr.press/v5/titsias09a.html). AISTATS 2009.
# [^Gorinova]: Gorinova, M., Moore, D., and Hoffman, M. [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a). ICML 2020.
# [^Paciorek]: Paciorek, C. J. [Nonstationary Gaussian Processes for Regression and Spatial Modelling](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf). PhD thesis, Carnegie Mellon University, 2003.
Review comment: do we have a citation for this? 🤔
Reply: I was wondering about that. I've definitely seen the result floating around (and it's easy enough to prove) -- will have a hunt.
Reply: if it's that easy to prove, just do it here in the docs 😂