Mean-Field Variational Inference Made Easy

I had the hardest time trying to understand variational inference. All of the presentations I’ve seen (MacKay, Bishop, Wikipedia, Gelman’s draft for the third edition of Bayesian Data Analysis) are deeply tied up with the details of a particular model being fit. I wanted to see the algorithm and get the big picture before being overwhelmed with multivariate exponential family gymnastics.

Bayesian Posterior Inference

In the Bayesian setting (see my earlier post, What is Bayesian Inference?), we have a joint probability model p(y,\theta) for data y and parameters \theta, usually factored as the product of a likelihood and prior term, p(y,\theta) = p(y|\theta) p(\theta). Given some observed data y, Bayesian predictive inference is based on the posterior density p(\theta|y) \propto p(\theta, y) of the unknown parameter vector \theta given the observed data vector y. Thus we need to be able to estimate the posterior density p(\theta|y) to carry out Bayesian inference. Note that the posterior is a whole density function; we’re not just after a point estimate as in maximum likelihood estimation.
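To make "posterior \propto likelihood \times prior" concrete, here is a minimal sketch that normalizes p(y|\theta) p(\theta) on a grid for a single binomial success probability \theta. The Beta(2, 2) prior, the data (7 successes in 10 trials), the grid size, and the function names are illustrative assumptions. Because this prior happens to be conjugate, the exact posterior is Beta(9, 5), which gives a check on the grid answer.

```python
import math

def log_joint(theta, y, n, a, b):
    """Unnormalized log p(y, theta): binomial likelihood times a Beta(a, b) prior.
    Terms that do not depend on theta are dropped."""
    return (y + a - 1) * math.log(theta) + (n - y + b - 1) * math.log(1 - theta)

def grid_posterior_mean(y, n, a=2.0, b=2.0, size=10_000):
    """Approximate E[theta | y] by normalizing exp(log_joint) on a grid over (0, 1)."""
    grid = [(i + 0.5) / size for i in range(size)]  # midpoints avoid log(0)
    logs = [log_joint(t, y, n, a, b) for t in grid]
    top = max(logs)                                 # shift for numerical stability
    weights = [math.exp(v - top) for v in logs]
    total = sum(weights)                            # normalizer, proportional to p(y)
    return sum(t * w for t, w in zip(grid, weights)) / total

y, n = 7, 10
print(grid_posterior_mean(y, n))   # grid approximation of the posterior mean
print((2 + y) / (2 + 2 + n))       # exact Beta(9, 5) posterior mean, 9/14
```

A grid like this is only practical in one or two dimensions, which is why approximations such as MCMC and variational inference are needed for realistic models.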

Mean-Field Approximation

Variational inference approximates the Bayesian posterior density p(\theta|y) with a (simpler) density g(\theta|\phi) parameterized by some new parameters \phi. The mean-field form of variational inference factors the approximating density g over the components of \theta = (\theta_1,\ldots,\theta_J), with each factor g_j getting its own parameters \phi_j, as

g(\theta|\phi) = \prod_{j=1}^J g_j(\theta_j|\phi_j).
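As a small illustration of what the product form means computationally, here is a sketch that assumes, purely for illustration, that every factor g_j is a normal density with \phi_j = (\mu_j, \sigma_j); the post deliberately leaves the choice of g_j open. The function names and parameter values are hypothetical.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def mean_field_logpdf(theta, phi):
    """log g(theta | phi) = sum_j log g_j(theta_j | phi_j), with each g_j
    assumed (for illustration only) to be normal with phi_j = (mu_j, sigma_j)."""
    return sum(normal_logpdf(t, mu, s) for t, (mu, s) in zip(theta, phi))

def mean_field_sample(phi, rng):
    """Draw theta from g: the factors are independent, so each theta_j is drawn on its own."""
    return [rng.gauss(mu, s) for mu, s in phi]

phi = [(0.0, 1.0), (2.0, 0.5), (-1.0, 3.0)]  # hypothetical phi_1, phi_2, phi_3
theta = mean_field_sample(phi, random.Random(0))
print(theta, mean_field_logpdf(theta, phi))
```

Because log g is a sum of one-dimensional terms, this g has a diagonal covariance and cannot represent correlation between components of \theta.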

I’m going to put off actually defining the terms g_j until we see how they’re used in the variational inference algorithm.

What Variational Inference Does

The variational inference algorithm finds the value \phi^* for the parameters \phi of the approximation which minimizes the Kullback-Leibler divergence of g(\theta|\phi) from p(\theta|y),

\phi^* = \mbox{arg min}_{\phi} \ \mbox{KL}[ g(\theta|\phi) \ || \ p(\theta|y) ].
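The objective can be estimated by sampling from g, since \mbox{KL}[ g \ || \ p ] = \mathbb{E}_{g}[\log g(\theta|\phi) - \log p(\theta|y)]. The sketch below is a one-dimensional toy that assumes, hypothetically, that the posterior is N(1, 2^2) and g is N(0, 1), so the Monte Carlo estimate can be checked against the closed-form KL between two normals. Function names, draw counts, and seeds are illustrative choices.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def kl_monte_carlo(g_mu, g_sigma, log_p, draws=100_000, seed=1):
    """Estimate KL[g || p] = E_g[log g(theta) - log p(theta)] using draws from g = N(g_mu, g_sigma^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(draws):
        theta = rng.gauss(g_mu, g_sigma)
        total += normal_logpdf(theta, g_mu, g_sigma) - log_p(theta)
    return total / draws

def kl_normal_exact(m1, s1, m2, s2):
    """Closed-form KL[N(m1, s1^2) || N(m2, s2^2)], used here only as a check."""
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5

def log_post(t):
    """Hypothetical posterior p(theta | y) = N(1, 2^2)."""
    return normal_logpdf(t, 1.0, 2.0)

print(kl_monte_carlo(0.0, 1.0, log_post))   # Monte Carlo estimate
print(kl_normal_exact(0.0, 1.0, 1.0, 2.0))  # exact value, log 2 - 1/4
```

In real use p(\theta|y) is only known up to the constant p(y). Plugging in \log p(\theta, y) instead lowers the estimate by exactly \log p(y) for every \phi, so the minimizing \phi is the same; the negative of that shifted quantity, \mathbb{E}_{g}[\log p(\theta, y) - \log g(\theta|\phi)], is what is usually called the evidence lower bound (ELBO).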

The key idea here is that variational inference reduces posterior estimation to an optimization problem. Optimization is typically much faster than sampling-based approaches to posterior estimation such as Markov chain Monte Carlo (MCMC).

The main disadvantage of variational inference is that the posterior is only approximated (though as MacKay points out, just about any approximation is better than a delta function at a point estimate!). In particular, variational methods tend to underestimate posterior variance because of the direction of the KL divergence that is minimized: \mbox{KL}[ g \ || \ p ] heavily penalizes g for putting mass where p(\theta|y) has little, so g concentrates inside the high-density region of the posterior rather than covering all of it. Expectation propagation (EP) also converts posterior fitting to optimization of KL divergence, but EP uses the opposite direction of KL divergence, which tends to overestimate posterior variance.
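A standard textbook illustration of the underestimation (Bishop works through it) takes the posterior to be a correlated bivariate normal with covariance \Sigma and precision \Lambda = \Sigma^{-1}. The optimal mean-field factors are then normal with variances 1/\Lambda_{jj}, which can never exceed the true marginal variances \Sigma_{jj}. The sketch below just evaluates that formula for a 2 \times 2 covariance; the function name and correlation values are illustrative.

```python
def mean_field_variances(cov):
    """Variances of the optimal factorized approximation (minimizing KL[g || p])
    when p is a bivariate normal with covariance cov = [[a, b], [b, d]].
    Each factor variance is 1 / Lambda_jj, where Lambda = cov^{-1}."""
    (a, b), (_, d) = cov
    det = a * d - b * b
    lam11, lam22 = d / det, a / det  # diagonal entries of the 2x2 inverse
    return 1.0 / lam11, 1.0 / lam22

for rho in (0.0, 0.5, 0.9, 0.99):
    mf_var, _ = mean_field_variances([[1.0, rho], [rho, 1.0]])
    print(f"rho = {rho}: true marginal variance 1.0, mean-field variance {mf_var:.4f}")
```

With unit marginal variances the mean-field variance works out to 1 - \rho^2, which shrinks toward zero as the correlation grows even though each true marginal variance stays at 1.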

Variational Inference Algorithm

Given the Bayesian model p(y,\theta), observed data y, and functional terms g_j making up the approximation of the posterior p(\theta|y), the variational inference algorithm is:

  • \phi \leftarrow \mbox{random legal initialization}
  • \mbox{repeat}
    • \phi_{\mbox{\footnotesize old}} \leftarrow \phi
    • \mbox{for } j \mbox{ in } 1:J
      • \mbox{set } \phi_j \mbox{ such that } g_j(\theta_j|\phi_j) \propto \exp\left( \mathbb{E}_{g_{-j}}[\log p(\theta|y)] \right).
  • \mbox{until } ||\phi - \phi_{\mbox{\footnotesize old}}|| < \epsilon

The inner expectation is a function of \theta_j returning a single real value (an expected log density, which can be negative). Writing g_{-j} for the product of all the factors other than g_j, it is defined by

\mathbb{E}_{g_{-j}}[\log p(\theta|y)] = \int \cdots \int \left( \prod_{k \neq j} g_k(\theta_k|\phi_k) \right) \log p(\theta|y) \ d\theta_1 \cdots d\theta_{j-1} \, d\theta_{j+1} \cdots d\theta_J.
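To see what this expectation looks like in a simple case, here is a sketch that evaluates it by Monte Carlo for a hypothetical bivariate normal posterior N(\mu, \Lambda^{-1}) with correlation 0.9, using arbitrary draws for the current factor g_2. Reusing the same draws at each \theta_1, a second difference shows the expectation is quadratic in \theta_1 with second derivative -\Lambda_{11}, whatever g_2 is; exponentiating such a quadratic and normalizing gives a normal density in \theta_1 with variance 1/\Lambda_{11}. All names and numbers are illustrative assumptions.

```python
import random

def expected_log_target(theta1, mu, prec, theta2_draws):
    """Monte Carlo estimate of E_{g_2}[log p(theta_1, theta_2 | y)] as a function of theta_1,
    for a bivariate normal target N(mu, prec^{-1}) with its normalizing constant dropped.
    theta2_draws are samples from the current factor g_2."""
    total = 0.0
    for t2 in theta2_draws:
        d1, d2 = theta1 - mu[0], t2 - mu[1]
        total += -0.5 * (prec[0][0] * d1 * d1 + 2 * prec[0][1] * d1 * d2 + prec[1][1] * d2 * d2)
    return total / len(theta2_draws)

rho = 0.9
s = 1.0 / (1.0 - rho ** 2)
prec = [[s, -rho * s], [-rho * s, s]]  # inverse of the covariance [[1, rho], [rho, 1]]
mu = (1.0, -2.0)

rng = random.Random(0)
theta2_draws = [rng.gauss(0.5, 0.3) for _ in range(5_000)]  # hypothetical current g_2 = N(0.5, 0.3^2)

# Same draws at every theta_1, so the second difference isolates the quadratic term.
h = 0.1
f = [expected_log_target(t, mu, prec, theta2_draws) for t in (2.0 - h, 2.0, 2.0 + h)]
curvature = (f[0] - 2 * f[1] + f[2]) / h ** 2
print(curvature, -prec[0][0])
```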

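Despite the suggestive factorization of g and the coordinate-wise nature of the algorithm, variational inference does not simply approximate the posterior marginals p(\theta_j|y) independently: each update of g_j averages over the current values of all the other factors g_{-j}, so the factors are fit jointly.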
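Here is a minimal Python sketch of the coordinate loop, assuming (hypothetically) that the posterior is a bivariate normal N(\mu, \Lambda^{-1}) so that every update has a closed form: factor g_j is normal with variance 1/\Lambda_{jj} and mean \mu_j - (\Lambda_{jk}/\Lambda_{jj})(m_k - \mu_k), where m_k is the other factor's current mean. It starts from a fixed initialization rather than a random one, and the function name cavi_bivariate_normal is hypothetical.

```python
def cavi_bivariate_normal(mu, prec, m_init=(0.0, 0.0), eps=1e-10, max_iter=10_000):
    """Coordinate-ascent mean-field VI when the target p(theta | y) is N(mu, prec^{-1}) in 2-D.

    Factor j is g_j = N(m_j, 1 / prec[j][j]). Setting g_j proportional to
    exp(E_{g_-j}[log p]) fixes its variance and gives the mean updates below,
    each of which uses the other factor's current mean.
    """
    m = list(m_init)
    for sweep in range(1, max_iter + 1):
        m_old = list(m)
        m[0] = mu[0] - prec[0][1] / prec[0][0] * (m[1] - mu[1])  # update g_1 given g_2
        m[1] = mu[1] - prec[1][0] / prec[1][1] * (m[0] - mu[0])  # update g_2 given new g_1
        if max(abs(a - b) for a, b in zip(m, m_old)) < eps:     # ||phi - phi_old|| < eps
            break
    variances = [1.0 / prec[0][0], 1.0 / prec[1][1]]
    return m, variances, sweep

rho = 0.9
s = 1.0 / (1.0 - rho ** 2)
prec = [[s, -rho * s], [-rho * s, s]]  # inverse of the covariance [[1, rho], [rho, 1]]
m, v, sweeps = cavi_bivariate_normal((1.0, -2.0), prec)
print(m, v, sweeps)
```

The means converge to the true posterior means (1, -2), but the factor variances are 1 - \rho^2 = 0.19 rather than the true marginal variance of 1, the underestimation described earlier. Each sweep shrinks the error in the means by a factor of \rho^2, so strongly correlated posteriors take more sweeps.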

Defining the Approximating Densities

The trick is to choose the approximating factors so that we can compute parameter values \phi_j such that g_j(\theta_j|\phi_j) \propto \exp\left( \mathbb{E}_{g_{-j}}[\log p(\theta|y)] \right). Finding such approximating terms g_j(\theta_j|\phi_j) given a posterior p(\theta|y) is an art form unto itself. It’s much easier for models with conjugate priors, where the exponentiated expectation takes a standard exponential-family form. Bishop’s and MacKay’s books and Wikipedia present calculations for a wide range of exponential-family models.
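As a concrete conjugate case, here is a minimal sketch for the normal model with unknown mean \mu and precision \tau under a normal-gamma prior, the example Bishop works through. Here \theta = (\mu, \tau), and \exp(\mathbb{E}_{g_{-j}}[\log p]) comes out normal in \mu and gamma in \tau, so each update reduces to closed-form formulas for \phi = (\mu_N, \lambda_N, a_N, b_N). The hyperparameter defaults, the data, and the function name are made-up values for illustration.

```python
def cavi_normal_gamma(x, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0, eps=1e-12, max_iter=1_000):
    """Mean-field VI for the conjugate model
         x_n ~ N(mu, 1/tau),  mu | tau ~ N(mu0, 1/(lambda0 * tau)),  tau ~ Gamma(a0, b0)
    (shape a0, rate b0), with g(mu, tau) = N(mu | mu_n, 1/lambda_n) * Gamma(tau | a_n, b_n).
    Returns (mu_n, lambda_n, a_n, b_n)."""
    n = len(x)
    xbar = sum(x) / n
    # These two parameters do not depend on the other factor, so they are set once.
    mu_n = (lambda0 * mu0 + n * xbar) / (lambda0 + n)
    a_n = a0 + (n + 1) / 2
    e_tau = a0 / b0                          # start E[tau] at its prior mean
    for _ in range(max_iter):
        lambda_n = (lambda0 + n) * e_tau     # update g_mu given the current E[tau]
        # Expectation under g_mu of sum_n (x_n - mu)^2 + lambda0 * (mu - mu0)^2
        e_sq = (sum((xi - mu_n) ** 2 for xi in x) + n / lambda_n
                + lambda0 * ((mu_n - mu0) ** 2 + 1 / lambda_n))
        b_n = b0 + 0.5 * e_sq                # update g_tau given the current g_mu
        e_tau_old, e_tau = e_tau, a_n / b_n
        if abs(e_tau - e_tau_old) < eps:
            break
    return mu_n, (lambda0 + n) * e_tau, a_n, b_n

x = [2.1, 1.9, 2.4, 1.6, 2.0, 2.3, 1.8, 2.2]  # made-up data
mu_n, lambda_n, a_n, b_n = cavi_normal_gamma(x)
print(mu_n, 1 / lambda_n)  # mean and variance of g_mu
print(a_n / b_n)           # E[tau] under g_tau
```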

What if My Model is not Conjugate?

Unfortunately, I almost never work with conjugate priors (and even if I did, I’m not much of a hand at exponential-family algebra). Therefore, the following paper just got bumped to the top of my must-understand queue:

It’s great having Dave down the hall on sabbatical this year; one couldn’t ask for a better stand-in for Matt Hoffman. They are both insanely good at on-the-fly explanations at the blackboard (I love that we still have real chalk and high-quality boards).

7 Responses to “Mean-Field Variational Inference Made Easy”

  1. Eric Says:

    Thanks, Bob. Nice overview. Would it make any sense to try to minimize some symmetric divergence, such as the Jensen-Shannon divergence?

    • Bob Carpenter Says:

      It certainly makes sense, but the issue is whether you can come up with some computable way to do it. The trick to variational inference is that the hairy integral involved in the componentwise hill climbing can be solved for conjugate priors and approximated elsewhere.

  2. michael Says:

    I am actually not a fan of any Taylor-expansion-based approximation methods, including the reference above. They simply break the lower-bound condition (except for the first-order approximation).
    Actually, there is another paper, written by some Japanese researchers, whose title I cannot recall. They show that mean field (including for nonconjugate models) can be seen, from the dual problem, as minimizing a Bregman divergence block-coordinate-wise.

  3. Aki Vehtari Says:

    With power-EP method it is possible to use alpha-divergence which includes both KL divergences and symmetric Hellinger distance as special cases.

  4. brendan o'connor (@brendan642) Says:

    i thought the presentation in the Koller and Friedman textbook was helpful for variational inference — a bit less tied up with a particular model if i remember right. The Murphy textbook also has a general derivation for mean field, then references out to other chapters for particular instantiations (section 21.3)

  5. Rob Says:

    Is there any chance you could work through an example, such as the beta-binomial model that you covered in an earlier posting?
