I had the hardest time trying to understand variational inference. All of the presentations I’ve seen (MacKay, Bishop, Wikipedia, Gelman’s draft for the third edition of Bayesian Data Analysis) are deeply tied up with the details of a particular model being fit. I wanted to see the algorithm and get the big picture before being overwhelmed with multivariate exponential family gymnastics.
Bayesian Posterior Inference
In the Bayesian setting (see my earlier post, What is Bayesian Inference?), we have a joint probability model $p(y, \theta)$ for data $y$ and parameters $\theta$, usually factored as the product of a likelihood and prior term, $p(y, \theta) = p(y|\theta)\,p(\theta)$. Given some observed data $y$, Bayesian predictive inference is based on the posterior density $p(\theta|y)$ of the unknown parameter vector $\theta$ given the observed data vector $y$. Thus we need to be able to estimate the posterior density $p(\theta|y)$ to carry out Bayesian inference. Note that the posterior is a whole density function; we're not just after a point estimate as in maximum likelihood estimation.
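Spelled out (this is just Bayes's rule, included here for completeness):

$$p(\theta|y) = \frac{p(y,\theta)}{p(y)} = \frac{p(y|\theta)\,p(\theta)}{\int p(y|\theta')\,p(\theta')\,d\theta'}.$$

The integral in the denominator is what makes the posterior hard to compute exactly for most models.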
Mean-Field Approximation
Variational inference approximates the Bayesian posterior density $p(\theta|y)$ with a (simpler) density $q(\theta|\phi)$ parameterized by some new parameters $\phi$. The mean-field form of variational inference factors the approximating density $q(\theta|\phi)$ by component of $\theta = (\theta_1, \ldots, \theta_K)$, as

$$q(\theta|\phi) = \prod_{k=1}^K q_k(\theta_k|\phi_k).$$
I’m going to put off actually defining the terms until we see how they’re used in the variational inference algorithm.
What Variational Inference Does
The variational inference algorithm finds the value $\phi^*$ for the parameters $\phi$ of the approximation which minimizes the Kullback-Leibler divergence of $q(\theta|\phi)$ from $p(\theta|y)$,

$$\phi^* = \arg\min_{\phi}\ \mathrm{KL}[\,q(\theta|\phi)\,\|\,p(\theta|y)\,].$$
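For reference (this expansion is not in the original post, but it is the standard identity that makes the optimization workable even though the evidence $p(y)$ is unknown):

$$\mathrm{KL}[\,q(\theta|\phi)\,\|\,p(\theta|y)\,]
= \mathbb{E}_{q(\theta|\phi)}\!\left[\log\frac{q(\theta|\phi)}{p(\theta|y)}\right]
= \log p(y) - \mathbb{E}_{q(\theta|\phi)}\!\left[\log p(y,\theta) - \log q(\theta|\phi)\right].$$

Since $\log p(y)$ does not depend on $\phi$, minimizing the KL divergence is equivalent to maximizing the second term (the evidence lower bound, or ELBO), which involves only the joint density $p(y,\theta)$ and the approximation $q$.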
The key idea here is that variational inference reduces posterior estimation to an optimization problem. Optimization is typically much faster than approaches to posterior estimation such as Markov chain Monte Carlo (MCMC).
The main disadvantage of variational inference is that the posterior is only approximated (though as MacKay points out, just about any approximation is better than a delta function at a point estimate!). In particular, variational methods systematically underestimate posterior variance because of the direction of the KL divergence that is minimized. Expectation propagation (EP) also converts posterior fitting to optimization of KL divergence, but EP uses the opposite direction of KL divergence, which leads to overestimation of posterior variance.
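To make the direction-of-KL point concrete (my gloss, not the post's), the reverse divergence that EP works with is

$$\mathrm{KL}[\,p(\theta|y)\,\|\,q(\theta|\phi)\,] = \int p(\theta|y)\,\log\frac{p(\theta|y)}{q(\theta|\phi)}\,d\theta.$$

This direction penalizes $q$ for missing regions where the posterior has mass, so it tends to produce approximations that are too wide; the direction used by variational inference, $\mathrm{KL}[q\,\|\,p]$, penalizes $q$ for putting mass where the posterior has little, so it tends to produce approximations that are too narrow.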
Variational Inference Algorithm
Given the Bayesian model $p(y,\theta)$, observed data $y$, and functional terms $q_k(\theta_k|\phi_k)$ making up the approximation of the posterior $p(\theta|y)$, the variational inference algorithm is:

- set the parameters $\phi$ to arbitrary initial values
- repeat until $\phi$ converges:
  - for $k$ in $1{:}K$:
    - update $\phi_k$ so that $q_k(\theta_k|\phi_k) \propto \exp\!\big(\mathbb{E}_{q(\theta_{-k}|\phi_{-k})}[\log p(y,\theta)]\big)$.

The inner expectation is a function of $\theta_k$ returning a single value, defined by

$$\mathbb{E}_{q(\theta_{-k}|\phi_{-k})}[\log p(y,\theta)]
= \int \log p(y,\theta)\,\prod_{j \ne k} q_j(\theta_j|\phi_j)\,d\theta_{-k},$$

where $\theta_{-k}$ is $\theta$ with its $k$-th component removed. Exponentiating this expectation yields a non-negative function of $\theta_k$, which is then normalized to give the updated factor $q_k$.
Despite the suggestive factorization of $q(\theta|\phi)$ and the coordinate-wise nature of the algorithm, variational inference does not simply approximate the posterior marginals $p(\theta_k|y)$ independently.
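To make the coordinate-wise updates concrete, here is a minimal sketch in Python (my own illustration, not from the post): a mean-field Gaussian approximation to a correlated two-dimensional Gaussian "posterior", for which the exp-of-expected-log update has a well-known closed form in terms of the precision matrix. The function name and example numbers are made up for the demo.

    import numpy as np

    def cavi_bivariate_gaussian(mu, Sigma, n_iters=100):
        """Mean-field coordinate updates for a 2-D Gaussian target N(mu, Sigma).

        Each factor q_k(theta_k) is Gaussian; the exp(E[log p]) update from the
        algorithm above has a closed form in terms of the precision matrix.
        """
        Lambda = np.linalg.inv(Sigma)      # precision matrix of the target
        m = np.zeros(2)                    # means of q_1, q_2 (arbitrary initialization)
        for _ in range(n_iters):
            # update q_1 holding q_2 fixed
            m[0] = mu[0] - (Lambda[0, 1] / Lambda[0, 0]) * (m[1] - mu[1])
            # update q_2 holding q_1 fixed
            m[1] = mu[1] - (Lambda[1, 0] / Lambda[1, 1]) * (m[0] - mu[0])
        variances = 1.0 / np.diag(Lambda)  # each factor's variance is 1 / Lambda_kk
        return m, variances

    if __name__ == "__main__":
        mu = np.array([1.0, -1.0])
        Sigma = np.array([[1.0, 0.9],
                          [0.9, 1.0]])     # strongly correlated "posterior"
        m, var = cavi_bivariate_gaussian(mu, Sigma)
        print("approximate means:    ", m)    # converges to [1, -1]
        print("approximate variances:", var)  # [0.19, 0.19], true marginals are 1.0

On this toy example the means converge to the true means, but each factor's variance is $1/\Lambda_{kk} = 0.19$ versus a true marginal variance of $1.0$, which is exactly the systematic variance underestimation mentioned above.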
Defining the Approximating Densities
The trick is to choose the approximating factors $q_k(\theta_k|\phi_k)$ so that we can actually compute parameter values $\phi_k$ satisfying

$$q_k(\theta_k|\phi_k) \propto \exp\!\big(\mathbb{E}_{q(\theta_{-k}|\phi_{-k})}[\log p(y,\theta)]\big).$$

Finding such approximating terms $q_k$ given a posterior $p(\theta|y)$ is an art form unto itself. It's much easier for models with conjugate priors. Bishop's and MacKay's books and Wikipedia present calculations for a wide range of exponential-family models.
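As a brief gloss on why conjugacy helps (my addition, following the standard exponential-family treatment, not something derived in the post): if each complete conditional is in an exponential family,

$$p(\theta_k \mid \theta_{-k}, y) \;\propto\; \exp\!\big(\eta_k(\theta_{-k}, y)^\top t(\theta_k)\big),$$

then the optimal mean-field factor $q_k$ lies in the same family, with natural parameter equal to the expected natural parameter, $\mathbb{E}_{q(\theta_{-k}|\phi_{-k})}[\eta_k(\theta_{-k}, y)]$. Each coordinate update then reduces to computing that expectation.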
What if My Model is not Conjugate?
Unfortunately, I almost never work with conjugate priors (and even if I did, I’m not much of a hand at exponential-family algebra). Therefore, the following paper just got bumped to the top of my must understand queue:
- Wang, Chong and David M. Blei. 2012–2013. Variational Inference in Nonconjugate Models. arXiv:1209.4360.
It’s great having Dave down the hall on sabbatical this year; one couldn’t ask for a better stand-in for Matt Hoffman. They are both insanely good at on-the-fly explanations at the blackboard (I love that we still have real chalk and high-quality boards).
March 25, 2013 at 1:36 pm
Thanks, Bob. Nice overview. Would it make any sense to try to minimize some symmetric divergence, such as the Jensen-Shannon divergence?
March 25, 2013 at 2:36 pm
It certainly makes sense, but the issue is whether you can come up with some computable way to do it. The trick to variational inference is that the hairy integral involved in the componentwise hill climbing can be solved for conjugate priors and approximated elsewhere.
March 25, 2013 at 1:56 pm
I am actually not a fan of any Taylor-expansion-based approximation methods, including the reference above. They simply break the lower bound condition (except for the first-order approximation).
Actually, there is another paper, written by some Japanese researchers, whose title I cannot recall. They show that mean field (including for non-conjugate models), seen from the dual problem, amounts to minimizing a Bregman divergence block-coordinate-wise.
March 25, 2013 at 3:51 pm
With the power-EP method it is possible to use the alpha-divergence, which includes both KL divergences and the symmetric Hellinger distance as special cases.
March 31, 2013 at 6:36 pm
I thought the presentation in the Koller and Friedman textbook was helpful for variational inference; a bit less tied up with a particular model, if I remember right. The Murphy textbook also has a general derivation for mean field, then references out to other chapters for particular instantiations (section 21.3).
April 8, 2013 at 7:05 pm
Is there any chance you could work through an example, such as the beta-binomial model that you covered in an earlier posting?
April 8, 2013 at 7:16 pm
Not until I understand it a bit better. As I said, I’m not very good with these exponential family derivations.