diff --git a/Project.toml b/Project.toml index 4d78bbd4..ef6873d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ApproximateGPs" uuid = "298c2ebc-0411-48ad-af38-99e88101b606" authors = ["JuliaGaussianProcesses Team"] -version = "0.3.4" +version = "0.3.5" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" diff --git a/docs/src/userguide.md b/docs/src/userguide.md index 1ca7e9db..3476d72f 100644 --- a/docs/src/userguide.md +++ b/docs/src/userguide.md @@ -46,31 +46,9 @@ The approximate posterior constructed above will be a very poor approximation, s ```julia elbo(SparseVariationalApproximation(fz, q), fx, y) ``` -A detailed example of how to carry out such optimisation is given in [Regression: Sparse Variational Gaussian Process for Stochastic Optimisation with Flux.jl](@ref). For an example of non-conjugate inference, see [Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS](@ref). # Available Parametrizations -Two parametrizations of `q(u)` are presently available: [`Centered`](@ref) and [`NonCentered`](@ref). -The `Centered` parametrization expresses `q(u)` directly in terms of its mean and covariance. -The `NonCentered` parametrization instead parametrizes the mean and covariance of -`ε := cholesky(cov(u)).U' \ (u - mean(u))`. -These parametrizations are also known respectively as "Unwhitened" and "Whitened". - -The choice of parametrization can have a substantial impact on the time it takes for ELBO -optimization to converge, and which parametrization is better in a particular situation is -not generally obvious. -That being said, the `NonCentered` parametrization often converges in fewer iterations, so it is the default -- -it is what is used in all of the examples above. - -If you require a particular parametrization, simply use the 3-argument version of the -approximation constructor: -```julia -SparseVariationalApproximation(Centered(), fz, q) -SparseVariationalApproximation(NonCentered(), fz, q) -``` - -For a general discussion around these two parametrizations, see e.g. [^Gorinova]. -For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek]. - -[^Gorinova]: Gorinova, Maria and Moore, Dave and Hoffman, Matthew [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a) -[^Paciorek]: [Paciorek, Christopher Joseph. Nonstationary Gaussian processes for regression and spatial modelling. Diss. Carnegie Mellon University, 2003.](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf) +There are various ways to parametrise the approximate posterior. +See [The Various Pseudo-Point Approximation Parametrisations](@ref) for more info and +worked examples. 
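+If you want to use a particular parametrisation explicitly, it can be passed as the first
+argument to the constructor. For example (a minimal sketch, reusing the `fz` and `q`
+constructed above):
+```julia
+SparseVariationalApproximation(Centered(), fz, q)
+SparseVariationalApproximation(NonCentered(), fz, q)
+```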
diff --git a/examples/d-sparse-parametrisations/Project.toml b/examples/d-sparse-parametrisations/Project.toml
new file mode 100644
index 00000000..746832db
--- /dev/null
+++ b/examples/d-sparse-parametrisations/Project.toml
@@ -0,0 +1,17 @@
+[deps]
+AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
+ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
+CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
+Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
+KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/examples/d-sparse-parametrisations/script.jl b/examples/d-sparse-parametrisations/script.jl
new file mode 100644
index 00000000..0ce6668d
--- /dev/null
+++ b/examples/d-sparse-parametrisations/script.jl
@@ -0,0 +1,171 @@
+# # The Various Pseudo-Point Approximation Parametrisations
+#
+# ### Note to the reader
+# At the time of writing (March 2021) the best way to parametrise the approximate posterior
+# remains a surprisingly active area of research.
+# If you are reading this and feel that it has become outdated, or was incorrect in the
+# first instance, it would be greatly appreciated if you could open an issue to discuss.
+#
+#
+# ## Introduction
+#
+# This example examines the various ways in which this package supports parametrising the
+# approximate posterior when utilising sparse approximations.
+#
+# All sparse (a.k.a. pseudo-point) approximations in this package utilise an approximate
+# posterior over a GP ``f`` of the form
+# ```math
+# q(f) = q(\mathbf{u}) \, p(f_{\neq \mathbf{u}} | \mathbf{u})
+# ```
+# where samples from ``f`` are functions mapping ``\mathcal{X} \to \mathbb{R}``,
+# ``\mathbf{u} := f(\mathbf{z})``, ``\mathbf{z} \in \mathcal{X}^M`` are the pseudo-inputs,
+# and ``f_{\neq \mathbf{u}}`` denotes ``f`` at all indices other than those in
+# ``\mathbf{z}``.[^Titsias]
+# ``q(\mathbf{u})`` is generally restricted to be a multivariate Gaussian, to which end ApproximateGPs presently offers four parametrisations:
+# 1. Centered ("Unwhitened"): ``q(\mathbf{u}) = \mathcal{N}(\mathbf{m}, \mathbf{C})``, ``\quad \mathbf{m} \in \mathbb{R}^M`` and positive-definite ``\mathbf{C} \in \mathbb{R}^{M \times M}``,
+# 1. Non-Centered ("Whitened"): ``q(\mathbf{u}) = \mathcal{N}(\mathbf{L} \mathbf{m}, \mathbf{L} \mathbf{C} \mathbf{L}^\top)``, ``\quad \mathbf{L} \mathbf{L}^\top = \text{cov}(\mathbf{u})``,
+# 1. Pseudo-Observation: ``q(\mathbf{u}) \propto p(\mathbf{u}) \, \mathcal{N}(\hat{\mathbf{y}}; \mathbf{u}, \hat{\mathbf{S}})``, ``\quad \hat{\mathbf{y}} \in \mathbb{R}^M`` and positive-definite ``\hat{\mathbf{S}} \in \mathbb{R}^{M \times M}``,
+# 1. Decoupled Pseudo-Observation: ``q(\mathbf{u}) \propto p(\mathbf{u}) \, \mathcal{N}(\hat{\mathbf{y}}; f(\mathbf{v}), \hat{\mathbf{S}})``, ``\quad \hat{\mathbf{y}} \in \mathbb{R}^R``, ``\hat{\mathbf{S}} \in \mathbb{R}^{R \times R}`` is positive-definite and diagonal, and ``\mathbf{v} \in \mathcal{X}^R``.
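+#
+# As a quick orientation, the following minimal sketch shows how each of these four
+# parametrisations is constructed in code. The variable names (`f_toy`, `z_toy`, and so on)
+# are purely illustrative and are unrelated to the worked examples below:
+# ```julia
+# using AbstractGPs, ApproximateGPs, Distributions, LinearAlgebra
+#
+# f_toy = GP(SEKernel())
+# z_toy = range(-3.0, 3.0; length=5)                    # M = 5 pseudo-inputs
+# q_toy = MvNormal(zeros(5), Matrix{Float64}(I, 5, 5))  # q(u) for parametrisations 1 and 2
+#
+# SparseVariationalApproximation(Centered(), f_toy(z_toy, 1e-9), q_toy)     # 1.
+# SparseVariationalApproximation(NonCentered(), f_toy(z_toy, 1e-9), q_toy)  # 2.
+#
+# ŷ_toy = randn(5)                  # pseudo-observations, one per pseudo-input
+# Ŝ_toy = Matrix{Float64}(I, 5, 5)  # their covariance
+# PseudoObsSparseVariationalApproximation(f_toy, z_toy, Ŝ_toy, ŷ_toy)       # 3.
+#
+# v_toy = range(-4.0, 4.0; length=8)  # R = 8 decoupled pseudo-observation inputs
+# PseudoObsSparseVariationalApproximation(
+#     f_toy, z_toy, Diagonal(fill(0.1, 8)), v_toy, randn(8)
+# )                                                                         # 4.
+# ```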
+#
+# The choice of parametrisation can have a substantial impact on the time it takes for ELBO
+# optimisation to converge, and which parametrisation is better in a particular situation is
+# not generally obvious.
+# That being said, the `NonCentered` parametrisation often converges in fewer iterations
+# than the `Centered`, and is widely used, so it is the default.
+#
+# For a general discussion of the centered vs non-centered parametrisations, see e.g. [^Gorinova].
+# For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek].
+
+# ## Setup
+
+using AbstractGPs
+using ApproximateGPs
+using CairoMakie
+using Distributions
+using Images
+using KernelFunctions
+using LinearAlgebra
+using Optim
+using Random
+using Zygote
+
+# A simple GP with inputs on the reals.
+f = GP(SEKernel());
+N = 100;
+x = range(-3.0, 3.0; length=N);
+
+# Generate some observations.
+Σ = Diagonal(fill(0.1, N));
+y = rand(Xoshiro(123456), f(x, Σ));
+
+# Use a handful of pseudo-points.
+M = 10;
+z = range(-3.5, 3.5; length=M);
+
+# Other misc. constants that we'll need later:
+x_pred = range(-5.0, 5.0; length=300);
+jitter = 1e-9;
+
+# ## The Relationship Between Parametrisations
+#
+# Much of the time, one can convert between the different parametrisations to obtain
+# equivalent ``q(\mathbf{u})``, for a given set of hyperparameters.
+# If it's unclear from the above how these parametrisations relate to one another, the
+# following should help to crystallise the relationship.
+#
+# ### Centered vs Non-Centered
+#
+# Both the `Centered` and `NonCentered` parametrisations are specified by a mean vector `m`
+# and covariance matrix `C`, but in slightly different ways.
+# The `Centered` parametrisation interprets `m` and `C` as the mean and covariance of
+# ``q(\mathbf{u})`` directly, while the `NonCentered` parametrisation interprets them as the
+# mean and covariance of the approximate posterior over
+# `ε := cholesky(cov(u)).U' \ (u - mean(u))`.
+#
+# To see this, consider the following non-centered approximate posterior:
+fz = f(z, jitter);
+qu_non_centered = MvNormal(randn(M), Matrix{Float64}(I, M, M));
+non_centered_approx = SparseVariationalApproximation(NonCentered(), fz, qu_non_centered);
+
+# The equivalent centered parametrisation can be found by multiplying the parameters of
+# `qu_non_centered` by the Cholesky factor of the prior covariance:
+L = cholesky(Symmetric(cov(fz))).L;
+qu_centered = MvNormal(L * mean(qu_non_centered), L * cov(qu_non_centered) * L');
+centered_approx = SparseVariationalApproximation(Centered(), fz, qu_centered);
+
+# We can gain some confidence that they're actually the same by querying the approximate
+# posterior statistics at some new locations:
+q_non_centered = posterior(non_centered_approx)
+q_centered = posterior(centered_approx)
+@assert mean(q_non_centered(x_pred)) ≈ mean(q_centered(x_pred))
+@assert cov(q_non_centered(x_pred)) ≈ cov(q_centered(x_pred))
+
+# ### Pseudo-Observation vs Centered
+#
+# The relationship between these two parametrisations is only slightly more complicated.
+# Consider the following pseudo-observation parametrisation of the approximate posterior:
+ŷ = randn(M);
+Ŝ = Matrix{Float64}(I, M, M);
+pseudo_obs_approx = PseudoObsSparseVariationalApproximation(f, z, Ŝ, ŷ);
+q_pseudo_obs = posterior(pseudo_obs_approx);
+
+# The corresponding centered approximation is given via the usual Gaussian conditioning
+# formulae:
+C = cov(fz);
+C_centered = C - C * (cholesky(Symmetric(C + Ŝ)) \ C);
+m_centered = mean(fz) + C / cholesky(Symmetric(C + Ŝ)) * (ŷ - mean(fz));
+qu_centered = MvNormal(m_centered, Symmetric(C_centered));
+centered_approx = SparseVariationalApproximation(Centered(), fz, qu_centered);
+q_centered = posterior(centered_approx);
+
+# Again, we can gain some confidence that they're the same by comparing the posterior
+# marginal statistics.
+@assert mean(q_pseudo_obs(x_pred)) ≈ mean(q_centered(x_pred))
+@assert cov(q_pseudo_obs(x_pred)) ≈ cov(q_centered(x_pred))
+
+# While it's always possible to find an approximation using the centered parametrisation
+# which is equivalent to a given pseudo-observation parametrisation, the converse is not
+# true.
+# That is, for a given `C = cov(fz)` and particular choice of covariance matrix `Ĉ` in a
+# centered parametrisation, it may not be the case that there exists a positive-definite
+# pseudo-observation covariance matrix `Ŝ` such that ``\hat{C} = C - C (C + \hat{S})^{-1} C``.
+#
+# However, this is not necessarily a problem: if the likelihood used in the model is
+# log-concave then the optimal choice for `Ĉ` can always be represented using this
+# pseudo-observation parametrisation.
+# Even when this is not the case, it is not guaranteed that the optimal choice for `q(u)`
+# lies outside the family of distributions that the pseudo-observation parametrisation can
+# express.
+
+#
+# ### Decoupled Pseudo-Observation vs Non-Centered
+#
+# The relationship here is the most delicate, due to the restriction that
+# ``\hat{\mathbf{S}}`` must be diagonal.
+# This approximation achieves the optimal approximate posterior when the pseudo-observational
+# data (``\hat{\mathbf{y}}``, ``\hat{\mathbf{S}}``, and ``\mathbf{v}``) are equal to the
+# original observational data.
+# When the original observational data involves a non-Gaussian likelihood, this
+# approximation family can still obtain the optimal approximate posterior provided that
+# ``\mathbf{v}`` lines up with the inputs associated with the original data, ``\mathbf{x}``.
+#
+# To see this, consider the pseudo-observation approximation which makes use of the
+# original observational data (generated at the top of this example):
decoupled_approx = PseudoObsSparseVariationalApproximation(f, z, Σ, x, y);
+decoupled_posterior = posterior(decoupled_approx);
+
+# We can get the optimal pseudo-point approximation using standard functionality:
+optimal_approx_post = posterior(VFE(f(z, jitter)), f(x, Σ), y);
+
+# The marginal statistics agree:
+@assert mean(optimal_approx_post(x_pred)) ≈ mean(decoupled_posterior(x_pred))
+@assert cov(optimal_approx_post(x_pred)) ≈ cov(decoupled_posterior(x_pred))
+
+# This property is the main reason to expect this parametrisation to do something sensible.
+# When ``\mathbf{v} \neq \mathbf{x}`` the optimal approximate posterior can obviously no
+# longer be recovered exactly; however, the hope is that there exists a small pseudo-dataset
+# which gets close to the optimum.
+
+# [^Titsias]: Titsias, M. K.
[Variational learning of inducing variables in sparse Gaussian processes](https://proceedings.mlr.press/v5/titsias09a.html) +# [^Gorinova]: Gorinova, Maria and Moore, Dave and Hoffman, Matthew [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a) +# [^Paciorek]: [Paciorek, Christopher Joseph. Nonstationary Gaussian processes for regression and spatial modelling. Diss. Carnegie Mellon University, 2003.](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf) diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 6d5e4d95..515dc0a3 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -15,7 +15,8 @@ include("SparseVariationalApproximationModule.jl") SparseVariationalApproximation, Centered, NonCentered @reexport using .SparseVariationalApproximationModule: DefaultQuadrature, Analytic, GaussHermite, MonteCarlo - +@reexport using .SparseVariationalApproximationModule: + PseudoObsSparseVariationalApproximation, ObsCovLikelihood, DecoupledObsCovLikelihood include("LaplaceApproximationModule.jl") @reexport using .LaplaceApproximationModule: LaplaceApproximation @reexport using .LaplaceApproximationModule: diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 0b6fdadb..48f6c4f5 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -2,7 +2,8 @@ module SparseVariationalApproximationModule using ..API -export SparseVariationalApproximation, Centered, NonCentered +export SparseVariationalApproximation, + Centered, NonCentered, PseudoObsSparseVariationalApproximation using ..ApproximateGPs: _chol_cov, _cov using Distributions @@ -29,6 +30,13 @@ using GPLikelihoods: GaussianLikelihood export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo include("expected_loglik.jl") +""" + abstract type AbstractSparseVariationalApproximation end + +Supertype for sparse variational approximations. +""" +abstract type AbstractSparseVariationalApproximation end + @doc raw""" Centered() @@ -60,7 +68,9 @@ See also [`Centered`](@ref). 
""" struct NonCentered end -struct SparseVariationalApproximation{Parametrization,Tfz<:FiniteGP,Tq<:AbstractMvNormal} +struct SparseVariationalApproximation{ + Parametrization,Tfz<:FiniteGP,Tq<:AbstractMvNormal +} <: AbstractSparseVariationalApproximation fz::Tfz q::Tq end @@ -191,14 +201,14 @@ function AbstractGPs.posterior(sva::SparseVariationalApproximation{NonCentered}) end function AbstractGPs.posterior( - sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real} + sva::AbstractSparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real} ) @assert sva.fz.f === fx.f return posterior(sva) end function AbstractGPs.posterior( - sva::SparseVariationalApproximation, lfx::LatentFiniteGP, ::Any + sva::AbstractSparseVariationalApproximation, lfx::LatentFiniteGP, ::Any ) @assert sva.fz.f === lfx.fx.f return posterior(sva) @@ -210,7 +220,7 @@ end # function Statistics.mean( - f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector + f::ApproxPosteriorGP{<:AbstractSparseVariationalApproximation}, x::AbstractVector ) return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α end @@ -225,21 +235,21 @@ end _A(f, x) = first(_A_and_Kuf(f, x)) function Statistics.cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector + f::ApproxPosteriorGP{<:AbstractSparseVariationalApproximation}, x::AbstractVector ) A = _A(f, x) return cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A) end function Statistics.var( - f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector + f::ApproxPosteriorGP{<:AbstractSparseVariationalApproximation}, x::AbstractVector ) A = _A(f, x) return var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A) end function StatsBase.mean_and_cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector + f::ApproxPosteriorGP{<:AbstractSparseVariationalApproximation}, x::AbstractVector ) A, Kuf = _A_and_Kuf(f, x) μ = mean(f.prior, x) + Kuf' * f.data.α @@ -248,7 +258,7 @@ function StatsBase.mean_and_cov( end function StatsBase.mean_and_var( - f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector + f::ApproxPosteriorGP{<:AbstractSparseVariationalApproximation}, x::AbstractVector ) A, Kuf = _A_and_Kuf(f, x) μ = mean(f.prior, x) + Kuf' * f.data.α @@ -257,7 +267,7 @@ function StatsBase.mean_and_var( end function Statistics.cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation}, + f::ApproxPosteriorGP{<:AbstractSparseVariationalApproximation}, x::AbstractVector, y::AbstractVector, ) @@ -278,14 +288,19 @@ inducing_points(f::ApproxPosteriorGP{<:SparseVariationalApproximation}) = f.appr # function API.approx_lml( - sva::SparseVariationalApproximation, l_fx::Union{FiniteGP,LatentFiniteGP}, ys; kwargs... + sva::AbstractSparseVariationalApproximation, + l_fx::Union{FiniteGP,LatentFiniteGP}, + ys; + kwargs..., ) return AbstractGPs.elbo(sva, l_fx, ys; kwargs...) end +_get_prior(approx::SparseVariationalApproximation) = approx.fz.f + """ elbo( - sva::SparseVariationalApproximation, + sva::AbstractSparseVariationalApproximation, fx::FiniteGP, y::AbstractVector{<:Real}; num_data=length(y), @@ -311,18 +326,18 @@ variational Gaussian process classification." Artificial Intelligence and Statistics. PMLR, 2015. 
""" function AbstractGPs.elbo( - sva::SparseVariationalApproximation, + sva::AbstractSparseVariationalApproximation, fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}}, y::AbstractVector{<:Real}; num_data=length(y), quadrature=DefaultQuadrature(), ) - @assert sva.fz.f === fx.f + @assert _get_prior(sva) === fx.f return _elbo(quadrature, sva, fx, y, GaussianLikelihood(fx.Σy[1]), num_data) end function AbstractGPs.elbo( - ::SparseVariationalApproximation, ::FiniteGP, ::AbstractVector; kwargs... + ::AbstractSparseVariationalApproximation, ::FiniteGP, ::AbstractVector; kwargs... ) return error( "The observation noise fx.Σy must be homoscedastic.\n", @@ -333,7 +348,7 @@ end """ elbo( - sva::SparseVariationalApproximation, + sva::AbstractSparseVariationalApproximation, lfx::LatentFiniteGP, y::AbstractVector; num_data=length(y), @@ -343,26 +358,26 @@ end Compute the ELBO for a LatentGP with a possibly non-conjugate likelihood. """ function AbstractGPs.elbo( - sva::SparseVariationalApproximation, + sva::AbstractSparseVariationalApproximation, lfx::LatentFiniteGP, y::AbstractVector; num_data=length(y), quadrature=DefaultQuadrature(), ) - @assert sva.fz.f === lfx.fx.f + @assert _get_prior(sva) === lfx.fx.f return _elbo(quadrature, sva, lfx.fx, y, lfx.lik, num_data) end # Compute the common elements of the ELBO function _elbo( quadrature::QuadratureMethod, - sva::SparseVariationalApproximation, + sva::AbstractSparseVariationalApproximation, fx::FiniteGP, y::AbstractVector, lik, num_data::Integer, ) - @assert sva.fz.f === fx.f + @assert _get_prior(sva) === fx.f f_post = posterior(sva) q_f = marginals(f_post(fx.x)) @@ -386,4 +401,159 @@ function _prior_kl(sva::SparseVariationalApproximation{NonCentered}) return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2 end +# Pseudo-Observation Parametrisations of q(u). + +@doc raw""" + PseudoObsSparseVariationalApproximation( + likelihood, f::AbstractGP, z::AbstractVector + ) + +Parametrises `q(f(z))`, the approximate posterior at `f(z)`, using a surrogate likelihood, +`likelihood`: `q(f(z)) ∝ p(f(z)) likelihood(f(z))`. +""" +struct PseudoObsSparseVariationalApproximation{ + Tlikelihood,Tf<:AbstractGP,Tz<:AbstractVector +} <: AbstractSparseVariationalApproximation + likelihood::Tlikelihood + f::Tf + z::Tz +end + +_get_prior(approx::PseudoObsSparseVariationalApproximation) = approx.f + +@doc raw""" + ObsCovLikelihood(S::AbstractMatrix{<:Real}, y::AbstractVector{<:Real}) + +Chooses `likelihood(u) = N(y; u, S)`. `length(y)` must be equal to the number of +pseudo-points utilised in the sparse variational approximation. +""" +struct ObsCovLikelihood{TS<:AbstractMatrix{<:Real},Ty<:AbstractVector{<:Real}} + S::TS + y::Ty +end + +@doc raw""" + PseudoObsSparseVariationalApproximation( + f::AbstractGP, + z::AbstractVector, + S::AbstractMatrix{<:Real}, + y::AbstractVector{<:Real}, + ) + +Convenience constuctor. 
+Equivalent to +```julia +PseudoObsSparseVariationalApproximation(ObsCovLikelihood(S, y), f, z) +``` +""" +function PseudoObsSparseVariationalApproximation( + f::AbstractGP, z::AbstractVector, S::AbstractMatrix{<:Real}, y::AbstractVector{<:Real} +) + return PseudoObsSparseVariationalApproximation(ObsCovLikelihood(S, y), f, z) +end + +function AbstractGPs.posterior( + approx::PseudoObsSparseVariationalApproximation{<:ObsCovLikelihood} +) + f = approx.f + z = approx.z + y = approx.likelihood.y + S = approx.likelihood.S + return posterior(f(z, S), y) +end + +function _prior_kl(approx::PseudoObsSparseVariationalApproximation{<:ObsCovLikelihood}) + f = approx.f + z = approx.z + y = approx.likelihood.y + S = approx.likelihood.S + + # log marginal probability of pseudo-observations. + logp_pseudo_obs = logpdf(f(z, S), y) + + # pseudo-reconstruction term. + m, C = mean_and_cov(posterior(approx)(z)) + S_chol = cholesky(AbstractGPs._symmetric(S)) + quad_form = sum(abs2, S_chol.U' \ (y - m)) + pseudo_lik = -(length(y) * AbstractGPs.log2π + logdet(S_chol) + quad_form) / 2 + trace_term = tr(S_chol \ C) / 2 + return -logp_pseudo_obs + pseudo_lik - trace_term +end + +@doc raw""" + DecoupledObsCovLikelihood( + S::AbstractMatrix{<:Real}, v::AbstractVector, y::AbstractVector{<:Real} + ) + +Chooses `likelihood(u) = N(y; f(v), S)` where `length(y)` need not be equal to the number +of pseudo-points, where `f` is the GP to which this likelihood specifies the approximate +posterior over `f(z)`. +""" +struct DecoupledObsCovLikelihood{ + TS<:Diagonal{<:Real},Tv<:AbstractVector,Ty<:AbstractVector{<:Real} +} + S::TS + v::Tv + y::Ty +end + +@doc raw""" + PseudoObsSparseVariationalApproximation( + f::AbstractGP, + z::AbstractVector, + S::Diagonal{<:Real}, + v::AbstractVector, + y::AbstractVector{<:Real}, + ) + +Convenience constructor. +Equivalent to +```julia +PseudoObsSparseVariationalApproximation(DecoupledObsCovLikelihood(S, v, y), f, z) +``` +""" +function PseudoObsSparseVariationalApproximation( + f::AbstractGP, + z::AbstractVector, + S::Diagonal{<:Real}, + v::AbstractVector, + y::AbstractVector{<:Real}, +) + return PseudoObsSparseVariationalApproximation(DecoupledObsCovLikelihood(S, v, y), f, z) +end + +function AbstractGPs.posterior( + approx::PseudoObsSparseVariationalApproximation{<:DecoupledObsCovLikelihood} +) + f = approx.f + z = approx.z + y = approx.likelihood.y + S = approx.likelihood.S + v = approx.likelihood.v + return posterior(AbstractGPs.VFE(f(z, 1e-9)), f(v, S), y) +end + +function _prior_kl( + approx::PseudoObsSparseVariationalApproximation{<:DecoupledObsCovLikelihood} +) + f = approx.f + z = approx.z + y = approx.likelihood.y + S = approx.likelihood.S + v = approx.likelihood.v + + # log marginal probability of pseudo-observations. Utilises DTC code. + logp_pseudo_obs = AbstractGPs.dtc(AbstractGPs.VFE(f(z)), f(v, S), y) + + # pseudo-reconstruction term. 
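+    # The statistics of q(u) are computed at z (a tiny jitter is added to help keep the
+    # factorisations numerically stable) and mapped onto f(v) via the prior conditional
+    # At = Kzz⁻¹ Kzv; the reconstruction then combines the Gaussian log density of the
+    # pseudo-observations at the implied mean with a trace correction involving the
+    # explained covariance At' * Ĉ * At.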
+ m̂, Ĉ = mean_and_cov(posterior(approx)(z, 1e-18)) + At = cholesky(AbstractGPs._symmetric(cov(f(z, 1e-18)))) \ cov(f, z, v) + m = mean(f, v) + At' * (m̂ - mean(f, z)) + pseudo_loglik = sum(map((m, s, y) -> logpdf(Normal(m, sqrt(s)), y), m, diag(S), y)) + pseudo_trace_term = sum(Ĉ .* (At * (S \ At'))) / 2 + pseudo_reconstruction = (pseudo_loglik - pseudo_trace_term) + + return -logp_pseudo_obs + pseudo_reconstruction +end + end diff --git a/test/SparseVariationalApproximationModule.jl b/test/SparseVariationalApproximationModule.jl index 3deb7f34..e31ae49b 100644 --- a/test/SparseVariationalApproximationModule.jl +++ b/test/SparseVariationalApproximationModule.jl @@ -190,4 +190,96 @@ @test all(isapprox.(cov(gpr_post, x), cov(svgp_post, x), atol=1e-4)) end end + @testset "PseudoObs" begin + rng = Xoshiro(123456) + + # Generate data. + f = GP(sin, SEKernel()) + x = range(-5.0, 5.0; length=11) + s = 0.1 + y = rand(rng, f(x, s)) + + z = range(-6.0, 6.0; length=7) + + @testset "Coupled Formulation" begin + + # Generate pseudo-data. + ŷ = randn(rng, length(z)) + _S = randn(rng, length(z), length(z)) + Ŝ = _S * _S' + I + + # Construct approximate posterior. + approx = ApproximateGPs.SparseVariationalApproximationModule.PseudoObsSparseVariationalApproximation( + f, z, Ŝ, ŷ + ) + + approx = ApproximateGPs.SparseVariationalApproximationModule.PseudoObsSparseVariationalApproximation( + f, z, Ŝ, ŷ + ) + approx_posterior = posterior(approx) + AbstractGPs.TestUtils.test_internal_abstractgps_interface( + rng, approx_posterior, x, z + ) + + # Check that the posterior is close to an equivalent Centered approximation. + @testset "compare against equivalent centered" begin + qu = approx_posterior(z, 1e-12) + approx_centered = SparseVariationalApproximation( + Centered(), f(z, 1e-12), qu + ) + approx_post_centered = posterior(approx_centered) + approx_centered = SparseVariationalApproximation( + Centered(), f(z, 1e-12), qu + ) + approx_post_x = approx_posterior(x, s) + approx_post_centered_x = approx_post_centered(x, s) + @test mean(approx_post_x) ≈ mean(approx_post_centered_x) + @test cov(approx_post_x) ≈ cov(approx_post_centered_x) + @test elbo(approx, f(x, s), y) ≈ elbo(approx_centered, f(x, s), y) + end + + # Check that Zygote is able to run. Assume correctness of result. + Zygote.gradient(elbo, approx, f(x, s), y) + end + @testset "Decoupled Formulation" begin + + # Generate pseudo-data. + v = range(-5.0, 5.0; length=9) + ŷ = randn(rng, length(v)) + Ŝ = Diagonal(rand(rng, length(v)) .+ 0.1) + + # Construct approximate posterior. + approx = ApproximateGPs.SparseVariationalApproximationModule.PseudoObsSparseVariationalApproximation( + f, z, Ŝ, v, ŷ + ) + + approx = ApproximateGPs.SparseVariationalApproximationModule.PseudoObsSparseVariationalApproximation( + f, z, Ŝ, v, ŷ + ) + approx_posterior = posterior(approx) + AbstractGPs.TestUtils.test_internal_abstractgps_interface( + rng, approx_posterior, x, z + ) + + # Check that the posterior is close to an equivalent Centered approximation. 
+ @testset "compare against equivalent centered" begin + qu = approx_posterior(z, 1e-12) + approx_centered = SparseVariationalApproximation( + Centered(), f(z, 1e-12), qu + ) + approx_post_centered = posterior(approx_centered) + approx_centered = SparseVariationalApproximation( + Centered(), f(z, 1e-12), qu + ) + approx_post_x = approx_posterior(x, s) + approx_post_centered_x = approx_post_centered(x, s) + @test mean(approx_post_x) ≈ mean(approx_post_centered_x) + @test cov(approx_post_x) ≈ cov(approx_post_centered_x) + @test elbo(approx, f(x, s), y) ≈ elbo(approx_centered, f(x, s), y) + end + + # Check that Zygote is able to run. Assume correctness of result. + Zygote.gradient(elbo, approx, f(x, s), y) + end + end end