Solved – Comparing Laplace Approximation and Variational Inference

Does anyone know of any references that look at the relationship between the Laplace approximation and variational inference (with normal approximating distributions)? Namely, I'm looking for something like conditions on the distribution being approximated under which the two approximations coincide.

Edit: To give some clarification, suppose you want to approximate some distribution with density $f(\theta)$ which you only know up to proportionality. When using the Laplace approximation, you approximate it with the density of a normal distribution with mean $\hat{\mu}_1$ and covariance $\hat{\Sigma}_1$, where $\hat{\mu}_1=\arg\max_{\theta}f(\theta)$ and $\hat{\Sigma}_1=[-\nabla\nabla \log f(\theta)\mid_{\theta=\hat{\mu}_1}]^{-1}$. When using variational inference with a normal approximating distribution, you approximate it with the density of a normal distribution with mean $\hat{\mu}_2$ and covariance $\hat{\Sigma}_2$, where $(\hat{\mu}_2,\hat{\Sigma}_2)=\arg\min_{(\mu,\Sigma)}KL(\phi_{(\mu,\Sigma)}\,\|\,f)$, $KL$ is the Kullback-Leibler divergence, and $\phi_{(\mu,\Sigma)}$ denotes a normal density with mean $\mu$ and covariance $\Sigma$. Under what conditions do we have $(\hat{\mu}_1,\hat{\Sigma}_1)=(\hat{\mu}_2,\hat{\Sigma}_2)$?
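To make this concrete, here is a minimal one-dimensional sketch (my own illustration, not from any reference): it computes $(\hat\mu_1,\hat\Sigma_1)$ from the mode and curvature, and $(\hat\mu_2,\hat\Sigma_2)$ by maximizing the ELBO with Gauss-Hermite quadrature. The skew-normal target and all function names are assumptions made purely for illustration; for a skewed target the two fits generally differ, while for an exactly Gaussian target both recover its mean and covariance.

```python
import numpy as np
from scipy import optimize
from scipy.stats import norm

def log_f(theta):
    # Unnormalized log-density of a skew-normal target (shape parameter 2).
    return -0.5 * theta**2 + norm.logcdf(2.0 * theta)

def laplace_fit(x0=0.0, h=1e-4):
    # (mu_1, Sigma_1): mode of f and inverse negative curvature of log f at the mode.
    res = optimize.minimize(lambda t: -log_f(t[0]), x0=np.array([x0]))
    mu = res.x[0]
    curv = (log_f(mu + h) - 2.0 * log_f(mu) + log_f(mu - h)) / h**2
    return mu, -1.0 / curv

def gva_fit(n_quad=50):
    # (mu_2, Sigma_2): maximize the ELBO E_q[log f] + entropy(q) over q = N(mu, sd^2),
    # equivalent to minimizing KL(q || f) up to the unknown normalizing constant.
    x, w = np.polynomial.hermite_e.hermegauss(n_quad)  # nodes/weights for weight exp(-x^2/2)
    w = w / np.sqrt(2.0 * np.pi)                       # renormalize to a N(0, 1) expectation
    def neg_elbo(params):
        mu, log_sd = params
        sd = np.exp(log_sd)
        e_log_f = np.sum(w * log_f(mu + sd * x))
        entropy = 0.5 * np.log(2.0 * np.pi * np.e * sd**2)
        return -(e_log_f + entropy)
    mu, log_sd = optimize.minimize(neg_elbo, x0=np.zeros(2), method="Nelder-Mead").x
    return mu, np.exp(2.0 * log_sd)

print("Laplace (mean, variance):", laplace_fit())
print("GVA     (mean, variance):", gva_fit())
```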

I am not aware of any general results, but in this paper the authors have some thoughts on Gaussian variational approximations (GVAs) for generalized linear mixed models (GLMMs). Let $\vec y$ be the observed outcomes, $X$ be a fixed effect design matrix, $Z$ be a random effect design matrix, let $\vec U$ denote the unknown random effects, and consider a GLMM with densities:

$$ \begin{align*} f_{\vec Y\mid\vec U} (\vec y;\vec u) &= \exp\left(\vec y^\top(X\vec\beta + Z\vec u) - \vec 1^\top b(X\vec\beta + Z\vec u) + \vec 1^\top c(\vec y)\right) \\ f_{\vec U}(\vec u) &= \phi^{(K)}(\vec u;\vec 0, \Sigma) \\ f(\vec y,\vec u) &= f_{\vec Y\mid\vec U} (\vec y;\vec u)f_{\vec U}(\vec u) \end{align*} $$

where I use the same notation as in the paper and $\phi^{(K)}$ is the density function of a $K$-dimensional multivariate normal distribution.
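To have something concrete to compute with in the steps below, here is a small simulated instance of this setup (entirely my own construction, not code or data from the paper), specialized to the Poisson case where $b(\eta) = e^\eta$ and $c(y) = -\log y!$, with simple group-indicator random effects:

```python
# Hypothetical Poisson GLMM used in the sketches below: b(eta) = exp(eta),
# c(y) = -log(y!), and Z contains group indicators. The data, dimensions, and
# parameter values are all made up for illustration.
import numpy as np
from scipy.special import gammaln

rng = np.random.default_rng(0)
n, K = 200, 3
X = np.column_stack([np.ones(n), rng.normal(size=n)])        # fixed effect design
Z = rng.multinomial(1, [1.0 / K] * K, size=n).astype(float)  # random effect design
beta = np.array([0.5, 0.3])
Sigma = 0.5 * np.eye(K)
u_true = rng.multivariate_normal(np.zeros(K), Sigma)
y = rng.poisson(np.exp(X @ beta + Z @ u_true)).astype(float)

def g(u):
    """g(u) = log f(y, u) for the Poisson GLMM."""
    eta = X @ beta + Z @ u
    log_lik = y @ eta - np.exp(eta).sum() - gammaln(y + 1.0).sum()
    log_prior = -0.5 * (u @ np.linalg.solve(Sigma, u)
                        + np.linalg.slogdet(2.0 * np.pi * Sigma)[1])
    return log_lik + log_prior
```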

Using a Laplace Approximation

Let

$$ g(\vec u) = \log f(\vec y,\vec u). $$

Then we use the approximation

$$ \log\int \exp(g(\vec u))\, d\vec u \approx \frac K2\log 2\pi - \frac 12\log\lvert-g''(\widehat u)\rvert + g(\widehat u) $$

where

$$ \widehat u = \text{argmax}_{\vec u}\, g(\vec u). $$
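Continuing the hypothetical Poisson sketch above, the Laplace approximation of the log marginal likelihood follows directly from this formula; in the Poisson case $b''(\eta) = e^\eta$, so $-g''(\widehat u) = \Sigma^{-1} + Z^\top\text{diag}(e^{\widehat\eta})Z$:

```python
# Laplace approximation of log ∫ exp(g(u)) du, reusing X, Z, y, beta, Sigma and
# g from the Poisson sketch above (all hypothetical).
from scipy import optimize

def laplace_log_marginal():
    res = optimize.minimize(lambda u: -g(u), x0=np.zeros(K))   # u_hat = argmax g(u)
    u_hat = res.x
    eta = X @ beta + Z @ u_hat
    neg_hess = np.linalg.inv(Sigma) + Z.T @ (np.exp(eta)[:, None] * Z)  # -g''(u_hat)
    value = (0.5 * K * np.log(2.0 * np.pi)
             - 0.5 * np.linalg.slogdet(neg_hess)[1]
             + g(u_hat))
    return value, u_hat

laplace_value, u_hat = laplace_log_marginal()
print("Laplace approximation of the log marginal likelihood:", laplace_value)
```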

Using a Gaussian Variational Approximation

The lower bound in the GVA with a mean $\vec\mu$ and covariance matrix $\Lambda$ is:

$$ \begin{align*} \log\int \exp(g(\vec u))\, d\vec u &\geq \vec y^\top(X\vec\beta + Z\vec\mu) - \vec 1^\top B(X\vec\beta + Z\vec\mu, \text{diag}(Z\Lambda Z^\top)) \\ &\hspace{25pt}+ \vec 1^\top c(\vec y) + \frac 12 \Big( \log\lvert\Sigma^{-1}\rvert + \log\lvert\Lambda\rvert -\vec\mu^\top\Sigma^{-1}\vec\mu \\ &\hspace{25pt} - \text{trace}(\Sigma^{-1}\Lambda) + K \Big) \\ B(\mu,\sigma^2) &= \int b(\sigma x + \mu)\phi(x)\, d x \end{align*} $$

where $\text{diag}(\cdot)$ returns a diagonal matrix.
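In the Poisson case, $B(\mu, \sigma^2) = \int e^{\sigma x + \mu}\phi(x)\,dx = e^{\mu + \sigma^2/2}$, so the bound has a closed form. Here is a sketch of it as a function (the signature is my own choice; it can be evaluated with the arrays from the Poisson sketch above):

```python
# Sketch of the GVA lower bound, specialized to the Poisson case where
# b(eta) = exp(eta) and hence B(mu, sigma^2) = exp(mu + sigma^2 / 2).
import numpy as np
from scipy.special import gammaln

def gva_lower_bound(mu, Lambda, X, Z, y, beta, Sigma):
    K = len(mu)
    eta = X @ beta + Z @ mu
    s2 = np.einsum("ij,jk,ik->i", Z, Lambda, Z)         # diag(Z Lambda Z^T)
    return (y @ eta - np.exp(eta + 0.5 * s2).sum()      # y^T eta - 1^T B(eta, s2)
            - gammaln(y + 1.0).sum()                    # 1^T c(y)
            + 0.5 * (-np.linalg.slogdet(Sigma)[1]       # (1/2) log|Sigma^{-1}|
                     + np.linalg.slogdet(Lambda)[1]     # (1/2) log|Lambda|
                     - mu @ np.linalg.solve(Sigma, mu)
                     - np.trace(np.linalg.solve(Sigma, Lambda))
                     + K))
```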

Comparing the Two

Suppose that we can show that $\Lambda\rightarrow 0$ (the estimated conditional covariance matrix of the random effects tends towards zero). Then the lower bound (disregarding the $\frac 12\log\lvert\Lambda\rvert$ determinant term) tends towards:

$$ \begin{align*} \log\int \exp(g(\vec u))\, d\vec u &\approx \vec y^\top(X\vec\beta + Z\vec\mu) - \vec 1^\top b(X\vec\beta + Z\vec\mu) \\ &\hspace{25pt}+ \vec 1^\top c(\vec y) + \frac 12 \Big( \log\lvert\Sigma^{-1}\rvert -\vec\mu^\top\Sigma^{-1}\vec\mu + K\Big) \\ &= g(\vec\mu) + \dots \end{align*} $$

where the dots do not depend on the model parameters $\vec\beta$ and $\Sigma$, nor on $\vec\mu$. Thus, maximizing over $\vec\mu$ yields $\vec\mu\rightarrow \widehat u$. Then the only difference between the Laplace approximation and the GVA is a

$$ - \frac 12\log\lvert -g''(\widehat u)\rvert $$

term. We have that

$$ -g''(\widehat u) = \Sigma^{-1} + Z^\top b''(X\vec\beta + Z\widehat u)Z $$

where the derivatives of $b$ are taken with respect to $\vec\eta = X\vec\beta + Z\vec u$. This term does not tend towards zero as the conditional distribution of the random effects becomes more peaked. However, though this is still very hand-wavy, it may cancel with the

$$ \frac 12\log\lvert\Lambda\rvert = -\frac 12\log\lvert\Lambda^{-1}\rvert $$

term we disregarded in the lower bound. The first order condition for $Lambda$ is:

$$ \Lambda^{-1} = \Sigma^{-1} + Z^\top B^{(2)}(X\vec\beta + Z\vec\mu, \text{diag}(Z\Lambda Z^\top))Z $$

where

$$ B^{(2)}(\mu,\sigma^2) = \int b''(\sigma x+ \mu)\phi(x)\, dx. $$
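In the Poisson case $B^{(2)}(\mu,\sigma^2) = e^{\mu + \sigma^2/2}$ as well, so the first-order condition can be iterated directly. A sketch, reusing the hypothetical data above and holding $\vec\mu$ fixed at $\widehat u$ for simplicity:

```python
# Fixed-point iteration for Lambda from the first-order condition, Poisson case;
# reuses X, Z, beta, Sigma, K, and u_hat from the sketches above (hypothetical).
def update_Lambda(mu, Lambda):
    eta = X @ beta + Z @ mu
    s2 = np.einsum("ij,jk,ik->i", Z, Lambda, Z)                 # diag(Z Lambda Z^T)
    M = np.linalg.inv(Sigma) + Z.T @ (np.exp(eta + 0.5 * s2)[:, None] * Z)
    return np.linalg.inv(M)

Lambda = np.eye(K)
for _ in range(50):
    Lambda = update_Lambda(u_hat, Lambda)   # mu held at u_hat for illustration
print("Fixed-point Lambda:\n", Lambda)
```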

Thus, if $\vec\mu \approx \widehat u$ and $\Lambda \approx 0$ then:

$$ \Lambda^{-1} \approx \Sigma^{-1} + Z^\top b''(X\vec\beta + Z\widehat u)Z = -g''(\widehat u) $$

so the $\frac 12\log\lvert\Lambda\rvert$ term we disregarded approximately equals $-\frac 12\log\lvert -g''(\widehat u)\rvert$, and the Laplace approximation and the GVA yield (approximately) the same approximation of the log marginal likelihood.
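As a rough numerical check of this claim, continuing the hypothetical Poisson sketches above, one can maximize the lower bound over $(\vec\mu, \Lambda)$ and compare the optimum with the Laplace value; with reasonably informative data the two numbers should be close:

```python
# Maximize the GVA bound over (mu, Lambda) via a Cholesky parameterization and
# compare with the Laplace value; reuses the Poisson sketches above (hypothetical).
from scipy import optimize

def neg_bound(params):
    mu = params[:K]
    L = np.zeros((K, K))
    L[np.tril_indices(K)] = params[K:]
    Lambda = L @ L.T + 1e-10 * np.eye(K)    # keep Lambda positive definite
    return -gva_lower_bound(mu, Lambda, X, Z, y, beta, Sigma)

init = np.concatenate([np.zeros(K), 0.1 * np.eye(K)[np.tril_indices(K)]])
res = optimize.minimize(neg_bound, init, method="Nelder-Mead",
                        options={"maxiter": 50000, "maxfev": 50000,
                                 "xatol": 1e-9, "fatol": 1e-10})
print("GVA lower bound:      ", -res.fun)
print("Laplace approximation:", laplace_value)
```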

Notes

Do also see the Annals paper that Ryan Warnick mentions.
