
10-708: Probabilistic Graphical Models, Spring 2012

13 : Variational Inference: Loopy Belief Propagation and Mean Field

Lecturer: Eric P. Xing Scribes: Peter Schulam and William Wang

1 Introduction

Inference problems involve answering a query about the likelihood of observed data. For example, to answer a query for a marginal $p(x_A)$, we can perform the marginalization operation $\sum_{x_{C \setminus A}} p(x)$. For queries concerning conditionals, such as $p(x_A \mid x_B)$, we can first compute the joint and then divide by the marginal $p(x_B)$. Sometimes, to answer a query, we might also need to compute the mode of the density, $\hat{x} = \arg\max_{x \in \mathcal{X}^m} p(x)$. So far in the class, we have covered exact inference. For exact inference, brute-force summation can be too inefficient for large graphs with complex structures, so a family of message-passing algorithms, such as forward-backward, sum-product, max-product, and the junction tree algorithm, was introduced. However, although these message-passing-based exact inference algorithms work well for tree-structured graphical models, it was also shown in class that they might not yield consistent results on loopy graphs, and their convergence is not guaranteed. Also, for complex graphical models such as the Ising model, exact inference algorithms such as the junction tree algorithm are computationally intractable. In this lecture, we look at two variational inference algorithms: loopy belief propagation (yww) and mean field approximation (pschulam).
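As a concrete illustration of these queries, here is a minimal brute-force sketch on a made-up joint table over three binary variables (the table entries are arbitrary); this exhaustive summation is exactly the computation that becomes intractable as graphs grow:

```python
import numpy as np

# Toy joint p(x1, x2, x3) over three binary variables, stored as a
# 2x2x2 table. The numbers are arbitrary; they only need to sum to 1.
p = np.array([[[0.04, 0.06], [0.10, 0.20]],
              [[0.12, 0.08], [0.25, 0.15]]])
assert np.isclose(p.sum(), 1.0)

# Marginal query p(x1): sum out the remaining variables x2, x3.
p_x1 = p.sum(axis=(1, 2))

# Conditional query p(x1 | x2 = 1): restrict the joint to x2 = 1,
# then divide by the marginal p(x2 = 1).
joint_x2_1 = p[:, 1, :].sum(axis=1)          # p(x1, x2 = 1)
p_x1_given_x2 = joint_x2_1 / joint_x2_1.sum()

# Mode query: the most probable full configuration.
mode = np.unravel_index(p.argmax(), p.shape)

print(p_x1, p_x1_given_x2, mode)
```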

2 Loopy Belief Propagation

The general idea of loopy belief propagation is that, even though the graph contains loops and the messages might circulate indefinitely, we run the algorithm anyway and hope for the best. In this section, we first review the basic belief propagation algorithm. Then, we discuss an experimental study by Murphy et al. (1999) and show some empirical results on the behavior of loopy belief propagation. Most importantly, we start from the notion of KL divergence and show how the LBP algorithm can be explained from the perspective of minimizing the Bethe free energy.

2.1 Belief Propagation: a Quick Review

The basic idea of belief propagation is very simple: to update the belief at a node, we collect the messages from its neighboring nodes and multiply them with the target node's singleton potential. To give a more concrete example, consider Figure 1. In part (a) of the figure, we see that in order to compute the message $M_{i \to j}(x_j)$, we first need the messages from all neighboring nodes $x_k$ to $x_i$, and then multiply with the singleton and doubleton potentials concerning $x_i$ and $(x_i, x_j)$:

$$M_{i \to j}(x_j) \propto \sum_{x_i} \Phi_{ij}(x_i, x_j)\, \Phi_i(x_i) \prod_{k \in N(i) \setminus j} M_{k \to i}(x_i) \qquad (1)$$


Figure 1: Belief propagation: an example.

Here the doubleton potential $\Phi_{ij}(x_i, x_j)$ is also called the compatibility, and models the interaction between the two nodes, whereas the singleton potential $\Phi_i(x_i)$ is also called the external evidence. On the right-hand side (part b), we can simply update the belief of $x_i$ using a similar formulation:

$$b_i(x_i) \propto \Phi_i(x_i) \prod_{k \in N(i)} M_{k \to i}(x_i) \qquad (2)$$

Similarly, for factor graphs, we also have the notion of "messages," and we update the belief of node $x_i$ by multiplying its factor with the messages coming from neighboring factor nodes:

$$b_i(x_i) \propto f_i(x_i) \prod_{a \in N(i)} m_{a \to i}(x_i) \qquad (3)$$

If we want to calculate the message from factor $a$ to node $x_i$, we sum over all the products:

$$m_{a \to i}(x_i) \propto \sum_{X_a \setminus x_i} f_a(X_a) \prod_{j \in N(a) \setminus i} m_{j \to a}(x_j) \qquad (4)$$

In class, we saw that running BP on trees always converges to the exact solution. However, this is not always the case for loopy graphs. The problem is that when a message is sent into a loop structure, it might circulate indefinitely; as a result, convergence is not guaranteed, and even when the algorithm converges it might converge to the wrong solution.
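To make the tree case concrete, here is a small sketch, with made-up potentials, that computes the messages of Eq. (1) and the belief of Eq. (2) at the middle node of a three-node chain and checks them against brute-force marginalization; since a chain is a tree, the two agree exactly:

```python
import numpy as np

# A 3-node chain x1 - x2 - x3 with binary states. Potentials are
# illustrative, not taken from the lecture.
phi = [np.array([0.7, 0.3]), np.array([0.5, 0.5]), np.array([0.2, 0.8])]
psi12 = np.array([[1.0, 0.5], [0.5, 1.0]])   # Phi_{12}(x1, x2)
psi23 = np.array([[1.0, 2.0], [2.0, 1.0]])   # Phi_{23}(x2, x3)

# Messages into the middle node, following Eq. (1): sum over the
# sender's state of (pairwise potential * sender's singleton * incoming
# messages). Leaves have no other neighbors, so that product is empty.
m1_to_2 = (psi12 * phi[0][:, None]).sum(axis=0)    # sum over x1
m3_to_2 = (psi23 * phi[2][None, :]).sum(axis=1)    # sum over x3

# Belief at node 2, Eq. (2): singleton times all incoming messages.
b2 = phi[1] * m1_to_2 * m3_to_2
b2 /= b2.sum()

# Brute-force check: enumerate all 2^3 joint configurations.
joint = (phi[0][:, None, None] * phi[1][None, :, None] *
         phi[2][None, None, :] * psi12[:, :, None] * psi23[None, :, :])
p2 = joint.sum(axis=(0, 2))
p2 /= p2.sum()
assert np.allclose(b2, p2)   # BP is exact on a tree
```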

2.2 Loopy Belief Propagation Algorithm

The idea of the loopy belief propagation algorithm is to use a fixed-point iteration procedure to minimize the Bethe free energy. Essentially, while the convergence criterion is not met, we keep updating the messages and the beliefs:

$$b_i(x_i) \propto \prod_{a \in N(i)} m_{a \to i}(x_i) \qquad (5)$$
$$b_a(X_a) \propto f_a(X_a) \prod_{i \in N(a)} m_{i \to a}(x_i) \qquad (6)$$
$$m^{new}_{i \to a}(x_i) = \prod_{c \in N(i) \setminus a} m_{c \to i}(x_i) \qquad (7)$$
$$m^{new}_{a \to i}(x_i) = \sum_{X_a \setminus x_i} f_a(X_a) \prod_{j \in N(a) \setminus i} m_{j \to a}(x_j) \qquad (8)$$

Therefore, the stationarity properties are guaranteed whenever the algorithm converges. The big problem, however, is that convergence itself is not guaranteed, and the reason is intuitive: when the BP algorithm runs on a graph that includes loops, the messages might circulate in the loops forever. Interestingly, Murphy et al. (UAI 1999) studied the empirical behavior of the loopy belief propagation algorithm and found that LBP can still achieve good approximations:

• The program is stopped after a fixed number of iterations.

• Alternatively, stop when there is no significant difference in the belief updates.

• When the solution converges, it is usually a good approximation.

This is probably the reason why LBP is still a very popular inference algorithm, even though its convergence might not be guaranteed. It was also mentioned in class that, in order to test the empirical performance of an approximate inference algorithm on large intractable problems, one can always start simple by testing on a small instance of the problem (e.g., a 20 × 20 Ising model).
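The sketch below puts the update equations (5)–(8) together for the pairwise case, using a fixed iteration cap and a message-change threshold as stopping rules, in the spirit of the heuristics above. The function name and data layout are our own illustrative choices, not code from the lecture:

```python
import numpy as np

def loopy_bp(phi, psi, edges, n_states, max_iter=100, tol=1e-6):
    """Parallel loopy BP on a pairwise MRF; a sketch of Eqs. (5)-(8).

    phi[i]      : singleton potential of node i, shape (n_states,)
    psi[(i, j)] : pairwise potential, shape (n_states, n_states), i < j
    edges       : list of (i, j) pairs with i < j
    """
    # Directed messages msgs[(i, j)] from i to j, initialized uniformly.
    msgs = {}
    for i, j in edges:
        msgs[(i, j)] = np.ones(n_states) / n_states
        msgs[(j, i)] = np.ones(n_states) / n_states

    nbrs = {i: [] for i in range(len(phi))}
    for i, j in edges:
        nbrs[i].append(j)
        nbrs[j].append(i)

    for _ in range(max_iter):           # fixed-iteration stopping rule
        new = {}
        for (i, j) in msgs:
            # Pairwise potential as a function of (sender, receiver).
            pot = psi[(i, j)] if (i, j) in psi else psi[(j, i)].T
            # Node potential times messages from neighbors other than j.
            prod = phi[i].copy()
            for k in nbrs[i]:
                if k != j:
                    prod *= msgs[(k, i)]
            m = pot.T @ prod            # sum over sender's states, Eq. (8)
            new[(i, j)] = m / m.sum()
        delta = max(np.abs(new[e] - msgs[e]).max() for e in msgs)
        msgs = new
        if delta < tol:                 # stop when messages stop changing
            break

    beliefs = []                        # node beliefs, Eq. (5)
    for i in range(len(phi)):
        b = phi[i].copy()
        for k in nbrs[i]:
            b *= msgs[(k, i)]
        beliefs.append(b / b.sum())
    return beliefs

# Example: a small 3-cycle. On a tree this would be exact; on this loopy
# graph the beliefs are approximations that may or may not converge.
phi = [np.array([0.6, 0.4]) for _ in range(3)]
psi = {e: np.ones((2, 2)) + np.eye(2) for e in [(0, 1), (1, 2), (0, 2)]}
print(loopy_bp(phi, psi, [(0, 1), (1, 2), (0, 2)], n_states=2))
```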

2.3 Understanding LBP: an $F_{Bethe}$ Minimization Perspective

To understand the LBP algorithm, let us first define the true distribution $P$ as:

$$P(X) = \frac{1}{Z} \prod_{f_a \in F} f_a(X_a) \qquad (9)$$

where $Z$ is the partition function and we are interested in the product of factors. Since inference in $P$ is often intractable, we approximate the distribution $P$ with a distribution $Q$. To do this, we can utilize the KL divergence:

$$KL(Q \| P) = \sum_X Q(X) \log \frac{Q(X)}{P(X)} \qquad (10)$$

Note that the KL divergence is asymmetric. Its value is non-negative, and it attains its minimum value of zero exactly when $P = Q$. KL is very useful in our problem because, as we show next, it can be evaluated up to a constant without performing inference in $P$:

$$KL(Q \| P) = \sum_X Q(X) \log \frac{Q(X)}{P(X)} \qquad (11)$$
$$= \sum_X Q(X) \log Q(X) - \sum_X Q(X) \log P(X) \qquad (12)$$
$$= -H_Q(X) - E_Q[\log P(X)] \qquad (13)$$

If we replace $P(X)$ with our earlier definition of the true distribution, we get:

$$KL(Q \| P) = -H_Q(X) - E_Q\Big[\log \frac{1}{Z} \prod_{f_a \in F} f_a(X_a)\Big] \qquad (14)$$
$$= -H_Q(X) - \log \frac{1}{Z} - \sum_{f_a \in F} E_Q[\log f_a(X_a)] \qquad (15)$$

Rearranging the terms on the right-hand side, we get:

$$KL(Q \| P) = -H_Q(X) - \sum_{f_a \in F} E_Q[\log f_a(X_a)] + \log Z \qquad (16)$$

Physicists call the first two terms on the right-hand side of this equation the "(Gibbs) free energy" $F(P, Q)$. Our goal therefore boils down to computing $F(P, Q)$. The term $\sum_{f_a \in F} E_Q[\log f_a(X_a)]$ can be computed using only the factor marginals, whereas computing $H_Q(X)$ is a much harder task that requires summing over all possible configurations, which is very expensive. However, we can approximate $F(P, Q)$ by a tractable surrogate $\hat{F}(P, Q)$. Before showing how to approximate the Gibbs free energy, let us first consider the case of tree graphical models, as in Fig. 2.

Figure 2: Calculating the tree free energy: an example.

Here, we know the probability can be written as $b(x) = \prod_a b_a(x_a) \prod_i b_i(x_i)^{1 - d_i}$, where $d_i$ is the number of factors neighboring node $i$, and $H_{tree}$ and $F_{tree}$ can be written as:

$$H_{tree} = -\sum_a \sum_{x_a} b_a(x_a) \log b_a(x_a) + \sum_i (d_i - 1) \sum_{x_i} b_i(x_i) \log b_i(x_i) \qquad (17)$$

$$F_{tree} = \sum_a \sum_{x_a} b_a(x_a) \log \frac{b_a(x_a)}{f_a(x_a)} + \sum_i (1 - d_i) \sum_{x_i} b_i(x_i) \log b_i(x_i) \qquad (18)$$
$$= F_{12} + F_{23} + \cdots + F_{67} + F_{78} - F_1 - F_5 - F_2 - F_6 - F_3 - F_7 \qquad (19)$$

From the above derivation, we can see that we only need to sum over the singletons and doubletons, which is easy to compute. We can use the same idea to approximate the Gibbs free energy on a general graph, such as the one in Fig. 3:

Figure 3: Calculating the Bethe free energy of a loopy graph: an example.

$$H_{Bethe} = -\sum_a \sum_{x_a} b_a(x_a) \log b_a(x_a) + \sum_i (d_i - 1) \sum_{x_i} b_i(x_i) \log b_i(x_i) \qquad (20)$$

$$F_{Bethe} = \sum_a \sum_{x_a} b_a(x_a) \log \frac{b_a(x_a)}{f_a(x_a)} + \sum_i (1 - d_i) \sum_{x_i} b_i(x_i) \log b_i(x_i) = -\langle \log f_a(x_a) \rangle - H_{Bethe} \qquad (21)$$
$$= F_{12} + F_{23} + \cdots + F_{67} + F_{78} - F_1 - F_5 - 2F_2 - 2F_6 - F_8 \qquad (22)$$

This is called the Bethe approximation of the Gibbs free energy. The idea is simple: we only need to sum over the singletons and the doubletons to obtain the entropy. Note, however, that unlike in the tree case, this approximation is no longer guaranteed to be exact on a loopy graph, so $F_{Bethe}$ may deviate from the true Gibbs free energy. Now, to minimize the Bethe free energy subject to the normalization and marginalization constraints on the beliefs, we can write out the Lagrangian:

$$\ell = F_{Bethe} + \sum_i \gamma_i \Big\{ 1 - \sum_{x_i} b_i(x_i) \Big\} + \sum_a \sum_{i \in N(a)} \sum_{x_i} \lambda_{ai}(x_i) \Big\{ b_i(x_i) - \sum_{X_a \setminus x_i} b_a(X_a) \Big\} \qquad (23)$$

Now, to solve this, we take the partial derivatives and set them to zero ($\frac{\partial \ell}{\partial b_i(x_i)} = 0$ and $\frac{\partial \ell}{\partial b_a(X_a)} = 0$). We then have:

$$b_i(x_i) \propto \exp\Big( \frac{1}{d_i - 1} \sum_{a \in N(i)} \lambda_{ai}(x_i) \Big) \qquad (24)$$
$$b_a(X_a) \propto \exp\Big( -E_a(X_a) + \sum_{i \in N(a)} \lambda_{ai}(x_i) \Big) \qquad (25)$$

where $E_a(X_a) = -\log f_a(X_a)$ is the energy of factor $a$.

Interestingly, if we let $\lambda_{ai}(x_i) = \log m_{i \to a}(x_i) = \log \prod_{b \in N(i) \setminus a} m_{b \to i}(x_i)$, and use the marginalization constraint $b_i(x_i) = \sum_{X_a \setminus x_i} b_a(X_a)$, then we obtain exactly the BP formulations in equations (3) and (4). This is very attractive, because we have shown how to derive the message-passing algorithm from the perspective of minimizing the Bethe free energy. In general, then, variational methods can be summarized as:

$$q^* = \arg\min_{q \in S} F_{Bethe}(p, q) \qquad (26)$$

where the optimization over $q$ is now tractable. Note that here we do not want to optimize $q(X)$ directly. Instead, we focus on a relaxed feasible set and approximate the objective:

$$b^* = \arg\min_{b \in M_o} \hat{F}(b) \qquad (27)$$

where $b$ covers the edge potentials (doubletons) and the node potentials (singletons). To solve for $b^*$, we typically use a fixed-point iteration algorithm.
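As a concrete companion to Eqs. (20)–(21), the sketch below evaluates the Bethe free energy of a pairwise MRF in the equivalent form $F_{Bethe} = U - H_{Bethe}$, where $U$ is the average energy under the beliefs; the beliefs could, for instance, come from a run of the loopy BP sketch earlier. The function name and the pairwise-MRF convention are our own assumptions:

```python
import numpy as np

def bethe_free_energy(b_node, b_edge, phi, psi, edges):
    """Bethe free energy F = U - H_Bethe for a pairwise MRF (a sketch).

    b_node[i]      : node belief b_i(x_i)
    b_edge[(i, j)] : edge belief b_ij(x_i, x_j) for (i, j) in edges
    phi[i]         : node potential; psi[(i, j)]: edge potential
    Factors here are the edges, so d_i is the degree of node i.
    """
    deg = {i: 0 for i in range(len(b_node))}
    for i, j in edges:
        deg[i] += 1
        deg[j] += 1

    # Average energy U: expectations of -log potentials under the beliefs.
    U = 0.0
    for e in edges:
        U -= np.sum(b_edge[e] * np.log(psi[e]))
    for i, b in enumerate(b_node):
        U -= np.sum(b * np.log(phi[i]))

    # Bethe entropy, Eq. (20): edge entropies, with node entropies
    # counted (d_i - 1) times to correct for over-counting.
    H = 0.0
    for e in edges:
        H -= np.sum(b_edge[e] * np.log(b_edge[e]))
    for i, b in enumerate(b_node):
        H += (deg[i] - 1) * np.sum(b * np.log(b))

    return U - H
```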

3 Mean Field Approximation

Recall that the purpose of approximate inference methods is to allow us to compute the posterior distribution over a model's latent variables even when the posterior involves an intractable integral or summation. As a motivating example, we will look at a Bayesian mixture of Gaussians with known variance $b^2$. To review, a mixture of $K$ Gaussians has the following generative story; a small sampling sketch follows the list:

• $\theta \sim \mathrm{Dir}(\alpha)$

• $\mu_k \sim \mathcal{N}(0, a^2)$ for $k \in \{1, \ldots, K\}$

• For $i \in \{1, \ldots, n\}$:

  – $z_i \sim \mathrm{Mult}(\theta)$
  – $x_i \sim \mathcal{N}(\mu_{z_i}, b^2)$
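For concreteness, this is what the generative story above looks like as a sampling sketch, using one-dimensional Gaussians and made-up hyperparameter values:

```python
import numpy as np

rng = np.random.default_rng(0)
K, n, a, b = 3, 100, 5.0, 1.0     # illustrative sizes and scales
alpha = np.ones(K)                 # symmetric Dirichlet prior

theta = rng.dirichlet(alpha)       # mixing weights  ~ Dir(alpha)
mu = rng.normal(0.0, a, size=K)    # cluster means   ~ N(0, a^2)
z = rng.choice(K, size=n, p=theta) # assignments     ~ Mult(theta)
x = rng.normal(mu[z], b)           # observations    ~ N(mu_{z_i}, b^2)
```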

Suppose that we want to compute a posterior distribution over the cluster assignments $z_i$ and cluster mean vectors $\mu_k$. This requires computing the following quantity, where $\mu = \{\mu_1, \ldots, \mu_K\}$, $z = \{z_1, \ldots, z_n\}$, and $x = \{x_1, \ldots, x_n\}$:

$$p(\mu, z \mid x) = \frac{\prod_{k=1}^K p(\mu_k) \prod_{i=1}^n p(z_i)\, p(x_i \mid z_i, \mu)}{\int_\mu \sum_z \prod_{k=1}^K p(\mu_k) \prod_{i=1}^n p(z_i)\, p(x_i \mid z_i, \mu)} \qquad (28)$$

We can easily compute the numerator, but the denominator is intractable because it involves a summation over all configurations of the latent cluster variables $z$. If there are $K$ clusters, then the number of configurations we would need to sum over is $K^n$. This is difficult by itself, and, in order to compute the denominator, we also need to compute the integral over all mean vectors $\mu$.

In the above posterior distribution, the denominator is difficult to compute because the latent variables are not easily factored. Note, in particular, that they are coupled in the conditional density of a particular data point $x_i$. In general, when the latent variables are coupled, we must sum an exponentially large number of terms in order to compute the normalizing quantity of a posterior distribution. If, however, the latent variables could be easily factored, then we might be able to compute the normalizing term more easily.

Broadly, mean field variational inference is a technique for designing a new family of distributions over the latent variables that do factorize well, which can then be used to approximate the posterior distribution over the Gaussian mixture model parameters and latent variables shown above. In symbols, the mean field approximation assumes that the variational distribution over the latent variables factorizes as

$$q(z_1, \ldots, z_m) = \prod_{i=1}^m q(z_i; \nu_i) \qquad (29)$$

More generally, we do not need to assume that the joint distribution over the latent variables factorizes into a separate term for each variable. We can allow broader families of variational distributions by instead assuming that the joint factorizes into independent distributions over clusters of the latent variables:

$$q(z_1, \ldots, z_m) = \prod_{C_i \in \mathcal{C}} q(C_i; \nu_i) \qquad (30)$$

where $\mathcal{C}$ is some set of disjoint sets of the latent variables.

3.1 Variational Inference Objective Functions

Since we are approximating a distribution with our variational distribution q, a natural way to measure the quality of our approximation is using the Kullback-Leibler (KL) divergence between the true density p and our approximation q:

$$KL(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)} \qquad (31)$$

This choice, however, is problematic, since it requires pointwise evaluation of $p(x)$, which is the very problem we are trying to solve in the first place. An alternative is to reverse the direction of the KL divergence:

$$KL(q \| p) = \sum_x q(x) \log \frac{q(x)}{p(x)} \qquad (32)$$

Assuming that our approximation $q(x)$ is tractable to compute, this objective is a slight improvement, but it still involves evaluating $p(x)$ in the denominator of the log. Note, however, that the unnormalized measure $\tilde{p}(x)$ can be written as $p(x)Z$, where $Z$ is the normalizing factor of the distribution. When $p(x)$ is a posterior $p(x \mid D)$, the normalizing constant $Z$ is $p(D)$. Using this fact, we can define a new objective function $J(q)$:

$$J(q) = KL(q \| \tilde{p}) \qquad (33)$$
$$= \sum_x q(x) \log \frac{q(x)}{\tilde{p}(x)} \qquad (34)$$
$$= \sum_x q(x) \log \frac{q(x)}{p(x) Z} \qquad (35)$$
$$= \sum_x q(x) \log \frac{q(x)}{p(x)} - \sum_x q(x) \log Z \qquad (36)$$
$$= \sum_x q(x) \log \frac{q(x)}{p(x)} - \log Z \qquad (37)$$

Since $-\log Z = -\log p(D)$ is a constant, minimizing $J(q) = KL(q \| \tilde{p})$ is equivalent to minimizing $KL(q \| p)$, the divergence between our approximation and the true distribution $p(x)$. Moreover, since $KL(q \| p) \geq 0$, we have $J(q) \geq -\log p(D)$, so $J(q)$ is an upper bound on the negative log likelihood of the evidence. An alternative is to maximize $-J(q)$, which is known as the energy functional.
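A quick numerical sanity check of this bound, on a made-up discrete distribution: for any $q$, $J(q) = KL(q \| p) - \log Z \geq -\log Z$, with equality exactly when $q = p$:

```python
import numpy as np

rng = np.random.default_rng(1)

# An unnormalized discrete target p~(x) over 8 states; Z is its sum.
p_tilde = rng.random(8) + 0.1
Z = p_tilde.sum()
p = p_tilde / Z

def J(q):
    """J(q) = KL(q || p~) = sum_x q(x) log(q(x) / p~(x))."""
    return np.sum(q * np.log(q / p_tilde))

# Any distribution q gives J(q) >= -log Z; q = p attains the bound.
q = rng.dirichlet(np.ones(8))
assert J(q) >= -np.log(Z) - 1e-12
assert np.isclose(J(p), -np.log(Z))
```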

3.2 Interpretations of the Objective Function

We can rewrite our objective function J(q) as

$$J(q) = E_q[\log q(x)] + E_q[-\log \tilde{p}(x)] \qquad (39)$$
$$= -H(q) + E_q[E(x)] \qquad (40)$$

where $E(x) = -\log \tilde{p}(x)$ is the energy. Intuitively, since we are minimizing $J(q)$, breaking the objective function down in this way shows that we are attempting to do two things. First, we want to minimize the negative entropy (i.e., increase the entropy), which, as we know from the maximum entropy principle, is a good way to ensure a distribution will generalize: we do not want to make unwarranted assumptions about the distribution, and should, in general, always seek the distribution that maximizes entropy. Second, we want to minimize the expected energy $E_q[-\log \tilde{p}(x)]$.

Recall that the energy is lower where the probability is higher, so minimizing energy corresponds to maximizing likelihood. The second term, at an intuitive level, makes sure that our approximate distribution puts more mass on $x$ with low energy according to $\tilde{p}(x)$, and less mass on $x$ with high energy. Another interpretation of the objective function $J(q)$ is

$$J(q) = E_q[\log q(x) - \log p(x)\, p(D \mid x)] \qquad (41)$$
$$= E_q[\log q(x) - \log p(x) - \log p(D \mid x)] \qquad (42)$$
$$= E_q[-\log p(D \mid x)] + KL(q \| p) \qquad (43)$$

We can again see that breaking down the objective function can give us an intuitive feel for what we are minimizing. We see that we are minimizing the expected negative log likelihood of the data conditioned on x, which prefers distributions that put more probability mass on x that increase the likelihood of our observed data. In addition, there is a term that penalizes distributions q that are too far from p.

3.3 Optimizing the Variational Distribution

Now that we understand the objective functions that we want to minimize (maximize in the case of the energy functional), we can address the issue of actually finding the variational distribution that minimizes (or maximizes) the objective function. In what follows, we use the energy functional objective function:

$$-J(q) = \sum_x q(x) \log \frac{\tilde{p}(x)}{q(x)} \qquad (44)$$

Recall that we wish to maximize this function. We will also use the simplest approximating distribution, which assumes that the joint density over all hidden variables $x$ factorizes completely. That is,

$$q(x_1, \ldots, x_m) = \prod_{i=1}^m q_i(x_i; \nu_i) \qquad (45)$$

Our strategy will be to use coordinate ascent to maximize $-J(q)$ with respect to each $q_i$. Once we have derived the update that optimizes the energy functional with respect to each $q_i$, it is relatively straightforward to extend the coordinate ascent updates to optimize the parameters $\nu_i$ of each local distribution $q_i$.

Let us first view $-J(q)$ as a function of one of the local distributions $q_i$, which we will write as $-J(q_i)$. We can then rewrite the objective function:

$$-J(q_i) = \sum_x q(x) \log \frac{\tilde{p}(x)}{q(x)} \qquad (46)$$
$$= \sum_x \Big( \prod_{i=1}^m q_i(x_i) \Big) \Big( \log \tilde{p}(x) - \sum_{j=1}^m \log q_j(x_j) \Big) \qquad (47)$$
$$= \sum_{x_i} q_i(x_i) \sum_{x_{-i}} \prod_{k \neq i} q_k(x_k) \Big( \log \tilde{p}(x) - \sum_{j \neq i} \log q_j(x_j) - \log q_i(x_i) \Big) \qquad (48)$$
$$= \sum_{x_i} q_i(x_i) \sum_{x_{-i}} \prod_{k \neq i} q_k(x_k) \log \tilde{p}(x) - \sum_{x_i} q_i(x_i) \sum_{x_{-i}} \prod_{k \neq i} q_k(x_k) \Big( \sum_{j \neq i} \log q_j(x_j) + \log q_i(x_i) \Big) \qquad (49)$$
$$\propto \sum_{x_i} q_i(x_i) \sum_{x_{-i}} \prod_{k \neq i} q_k(x_k) \log \tilde{p}(x) - \sum_{x_i} q_i(x_i) \sum_{x_{-i}} \prod_{k \neq i} q_k(x_k) \log q_i(x_i) \qquad (50)$$
$$= \sum_{x_i} q_i(x_i) \sum_{x_{-i}} \prod_{k \neq i} q_k(x_k) \log \tilde{p}(x) - \sum_{x_i} q_i(x_i) \log q_i(x_i) \qquad (51)$$
$$= \sum_{x_i} q_i(x_i)\, E_{-q_i}[\log \tilde{p}(x)] - \sum_{x_i} q_i(x_i) \log q_i(x_i) \qquad (52)$$

With this modified form, we can now define $E_{-q_i}[\log \tilde{p}(x)]$ to be the log of some function of $x_i$:

$$\log f_i(x_i) = E_{-q_i}[\log \tilde{p}(x)] \qquad (53)$$

which allows us to rewrite the final expression in our derivation above as:

$$\sum_{x_i} q_i(x_i) \log f_i(x_i) - \sum_{x_i} q_i(x_i) \log q_i(x_i) = -KL(q_i \| f_i) \qquad (54)$$

We then maximize our objective $-J(q_i)$ by minimizing $KL(q_i \| f_i)$, which is clearly done by setting $q_i(x_i) = f_i(x_i)$. Thus

$$\log f_i(x_i) = E_{-q_i}[\log \tilde{p}(x)] \qquad (55)$$
$$f_i(x_i) = \exp\big( E_{-q_i}[\log \tilde{p}(x)] \big) \qquad (56)$$

Therefore, the distribution $q_i$ that maximizes our objective function is

$$q_i(x_i) = \frac{1}{Z_i} \exp\big( E_{-q_i}[\log \tilde{p}(x)] \big) \qquad (57)$$

where $Z_i$ is a normalizing constant that ensures $q_i$ is a proper distribution. From this we can see that the approximate distribution over a particular hidden variable $x_i$ depends on the mean values of the rest of

the hidden variables. In the expectation, we can drop all terms that do not involve $x_i$, which removes the means of all variables that are not neighbors of $x_i$. We see, then, that the distribution for a variable $x_i$ depends on the mean values of its neighbors. This quantity is known as the mean field, which is where the name mean field approximation comes from.
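As a closing illustration, here is a sketch of these coordinate-ascent updates for an Ising model, where Eq. (57) reduces to the classic form $\mu_i = \tanh(b_i + \sum_j W_{ij} \mu_j)$: each update depends on the other variables only through their current means. The weights and field values below are made up:

```python
import numpy as np

def ising_mean_field(W, b, n_iter=200):
    """Coordinate-ascent mean field for an Ising model (a sketch).

    Model: p~(x) ∝ exp(0.5 * x^T W x + b^T x) with x_i in {-1, +1},
    W symmetric with zero diagonal. Eq. (57) gives
    q_i(x_i) ∝ exp(x_i * (b_i + sum_j W_ij mu_j)), so each update
    needs only the mean values mu_j of the neighboring variables.
    """
    n = len(b)
    mu = np.zeros(n)                  # mean of x_i under q_i
    for _ in range(n_iter):
        for i in range(n):            # update one q_i at a time
            mu[i] = np.tanh(b[i] + W[i] @ mu)
    return mu

# Example: a 3-node chain with weak couplings and a small external field.
W = np.array([[0.0, 0.5, 0.0],
              [0.5, 0.0, 0.5],
              [0.0, 0.5, 0.0]])
b = np.array([0.1, 0.0, -0.1])
print(ising_mean_field(W, b))
```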