The Junction Tree Algorithm
Probabilistic and Bayesian Machine Learning
Day 3: The Junction Tree Algorithm
Yee Whye Teh ([email protected])
Gatsby Computational Neuroscience Unit, University College London
http://www.gatsby.ucl.ac.uk/~ywteh/teaching/probmodels

Inference for HMMs

[Figure: hidden Markov model with latent chain $Y_1 \to Y_2 \to \cdots \to Y_\tau$ and an observation $X_t$ emitted from each $Y_t$.]

$$P(Y_{1:\tau}, X_{1:\tau}) = P(Y_1)\,P(X_1 \mid Y_1) \prod_{t=2}^{\tau} P(Y_t \mid Y_{t-1})\,P(X_t \mid Y_t)$$

Note that we suppress the dependence on the parameters $\theta$ for clarity.

Inference and learning problems:
• Filtering: $P(Y_t \mid X_1, \ldots, X_t)$
• Prediction: $P(Y_t \mid X_1, \ldots, X_{t-1})$
• Smoothing: $P(Y_t \mid X_1, \ldots, X_\tau)$ and $P(Y_t, Y_{t+1} \mid X_1, \ldots, X_\tau)$
• Likelihood: $P(X_1, \ldots, X_\tau)$
• Learning: $\theta^{ML} = \arg\max_\theta P(X_1, \ldots, X_\tau \mid \theta)$

Likelihood of HMMs

The likelihood of an HMM appears to be an intractable sum over an exponential number of terms:
$$P(X_{1:\tau}) = \sum_{Y_{1:\tau}} P(Y_{1:\tau}, X_{1:\tau})$$
Fortunately, there is a forward recursion that allows us to compute the likelihood efficiently. For $t = 1$,
$$\alpha_1(Y_1) := P(X_1, Y_1) = P(Y_1)\,P(X_1 \mid Y_1)$$
For $t = 2, \ldots, \tau$,
$$\alpha_t(Y_t) := P(X_{1:t}, Y_t) = \sum_{Y_{t-1}} \alpha_{t-1}(Y_{t-1})\,P(Y_t \mid Y_{t-1})\,P(X_t \mid Y_t)$$
The likelihood is then
$$P(X_{1:\tau}) = \sum_{Y_\tau} P(X_{1:\tau}, Y_\tau) = \sum_{Y_\tau} \alpha_\tau(Y_\tau)$$

Filtering and Prediction of HMMs

For discrete latent variables $Y_t$ with $K$ states, the computational cost is $O(\tau K^2)$ instead of $O(K^\tau)$.

The forward messages $\alpha_t(y)$ also give us the filtering probabilities once normalized:
$$P(Y_t \mid X_{1:t}) = \alpha_t(Y_t) \Big/ \sum_{m=1}^{K} \alpha_t(m)$$

Numerical stability: in practice it is important to normalize the forward messages after each step, as each $P(X_{1:t}, Y_t)$ can become exponentially small in $t$:
$$Z_t = \sum_{Y_t} \alpha_t(Y_t), \qquad \alpha'_t(Y_t) := \alpha_t(Y_t)/Z_t$$
where the recursion is run on the normalized messages $\alpha'_{t-1}$, so that $Z_t = P(X_t \mid X_{1:t-1})$. The log likelihood is then the sum of the log normalizers:
$$\log P(X_{1:\tau}) = \sum_{t=1}^{\tau} \log Z_t$$

Smoothing of HMMs

For smoothing we also need to incorporate information from observations at later times. This requires computing backward messages, which carry this information. For $t = \tau$,
$$\beta_\tau(Y_\tau) := 1$$
For $t = \tau - 1, \ldots, 1$,
$$\beta_t(Y_t) := P(X_{t+1:\tau} \mid Y_t) = \sum_{Y_{t+1}} P(Y_{t+1} \mid Y_t)\,P(X_{t+1} \mid Y_{t+1})\,\beta_{t+1}(Y_{t+1})$$
The marginal probabilities are then
$$P(Y_t \mid X_{1:\tau}) \propto P(X_{1:t}, Y_t)\,P(X_{t+1:\tau} \mid Y_t) = \alpha_t(Y_t)\,\beta_t(Y_t)$$
and the pairwise probabilities are
$$P(Y_t, Y_{t+1} \mid X_{1:\tau}) \propto \alpha_t(Y_t)\,P(Y_{t+1} \mid Y_t)\,P(X_{t+1} \mid Y_{t+1})\,\beta_{t+1}(Y_{t+1})$$
The algorithm is variously known as the forward-backward algorithm, the sum-product algorithm, or belief propagation.

Distributive Law

The belief propagation algorithm can be understood as implementing the distributive law of multiplication over addition. In particular, the backward messages are obtained by distributing the sums into the product from the right:
$$P(X_{1:\tau}) = \sum_{Y_{1:\tau}} P(Y_{1:\tau}, X_{1:\tau}) = \sum_{Y_{1:\tau}} P(Y_1)\,P(X_1 \mid Y_1) \prod_{t=2}^{\tau} P(Y_t \mid Y_{t-1})\,P(X_t \mid Y_t)$$
$$= \sum_{Y_1} P(Y_1)\,P(X_1 \mid Y_1) \sum_{Y_2} P(Y_2 \mid Y_1)\,P(X_2 \mid Y_2) \cdots \sum_{Y_\tau} P(Y_\tau \mid Y_{\tau-1})\,P(X_\tau \mid Y_\tau)$$
Similarly, the forward messages can be obtained by distributing the sums into the product in the reverse order. Other belief propagation algorithms can be obtained from different distributive laws:
• Max-product: multiplication distributes over maximization as well. The max-product algorithm computes the most likely state trajectory:
$$Y_{1:\tau}^{MAP} = \arg\max_{Y_{1:\tau}} P(Y_{1:\tau}, X_{1:\tau})$$
• Max-sum: taking logarithms before applying max-product gives the max-sum algorithm. This is more efficient (additions are cheaper than multiplications on CPUs) and more numerically stable.
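The recursions above translate directly into a few lines of array code. The following is a minimal NumPy sketch (not taken from the course materials): it assumes a discrete HMM specified by an initial distribution `pi`, a transition matrix `A` with rows indexed by $Y_{t-1}$, and an emission matrix `B`, and it normalizes the forward messages at each step so the log likelihood is recovered as the sum of the log normalizers $Z_t$.

```python
import numpy as np

def forward_backward(pi, A, B, x):
    """Normalized forward-backward recursions for a discrete HMM.

    pi : (K,)   initial distribution P(Y_1)
    A  : (K, K) transition matrix, A[i, j] = P(Y_t = j | Y_{t-1} = i)
    B  : (K, M) emission matrix,   B[k, m] = P(X_t = m | Y_t = k)
    x  : length-tau sequence of observed symbols (integers in 0..M-1)
    Returns filtered marginals, smoothed marginals and the log likelihood.
    """
    tau, K = len(x), len(pi)
    alpha = np.zeros((tau, K))   # normalized forward messages, P(Y_t | X_{1:t})
    beta = np.ones((tau, K))     # (rescaled) backward messages
    logZ = np.zeros(tau)         # per-step normalizers, Z_t = P(X_t | X_{1:t-1})

    # Forward recursion, normalizing after each step for numerical stability.
    a = pi * B[:, x[0]]
    for t in range(tau):
        if t > 0:
            a = (alpha[t - 1] @ A) * B[:, x[t]]
        logZ[t] = np.log(a.sum())
        alpha[t] = a / a.sum()

    # Backward recursion: beta_t(y) proportional to P(X_{t+1:tau} | Y_t = y).
    for t in range(tau - 2, -1, -1):
        b = A @ (B[:, x[t + 1]] * beta[t + 1])
        beta[t] = b / b.sum()    # rescale; the constants cancel below

    # Smoothed marginals P(Y_t | X_{1:tau}) proportional to alpha_t * beta_t.
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)

    return alpha, gamma, logZ.sum()

# Tiny usage example with made-up parameters (K = 2 states, M = 2 symbols).
pi = np.array([0.6, 0.4])
A = np.array([[0.7, 0.3], [0.2, 0.8]])
B = np.array([[0.9, 0.1], [0.3, 0.7]])
alpha, gamma, loglik = forward_backward(pi, A, B, x=[0, 1, 1, 0])
```

Each forward and backward step is a $K \times K$ matrix-vector product, which is where the $O(\tau K^2)$ cost comes from.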
Undirected Graphs and Partition Functions

$$P(X_{1:\tau}) = \sum_{Y_{1:\tau}} P(Y_1)\,P(X_1 \mid Y_1) \prod_{t=2}^{\tau} P(Y_t \mid Y_{t-1})\,P(X_t \mid Y_t)$$

Define factors for an undirected graphical model over the latent variables $Y_{1:\tau}$ only:
$$f_t(Y_t) = P(X_t \mid Y_t), \qquad f_{t,t+1}(Y_t, Y_{t+1}) = P(Y_{t+1} \mid Y_t)$$
The joint distribution given by the undirected graph is
$$P^U(Y_{1:\tau}) = \frac{1}{Z} \prod_t f_t(Y_t)\,f_{t,t+1}(Y_t, Y_{t+1})$$
The normalization constant (also called the partition function) of this distribution is precisely the likelihood:
$$P(X_{1:\tau}) = Z = \sum_{Y_{1:\tau}} \prod_t f_t(Y_t)\,f_{t,t+1}(Y_t, Y_{t+1})$$
while the smoothed marginals are just the marginal distributions of $P^U$:
$$P(Y_t \mid X_{1:\tau}) = \frac{P(Y_t, X_{1:\tau})}{P(X_{1:\tau})} = \frac{\sum_{Y_{1:t-1},\, Y_{t+1:\tau}} \prod_{t'} f_{t'}(Y_{t'})\,f_{t',t'+1}(Y_{t'}, Y_{t'+1})}{Z} = P^U(Y_t)$$
$$P(Y_t, Y_{t+1} \mid X_{1:\tau}) = P^U(Y_t, Y_{t+1})$$

Belief Propagation on Undirected Trees

Undirected tree parameterization of the joint distribution:
$$p(Y) = \frac{1}{Z} \prod_{\text{edges } (ij)} f_{(ij)}(Y_i, Y_j)$$
Define $T_{i\to j}$ to be the subtree containing $i$ when $j$ is removed. Define messages:
$$M_{i\to j}(Y_j) = \sum_{Y_{T_{i\to j}}} f_{(ij)}(Y_i, Y_j) \prod_{\text{edges } (i'j') \text{ in } T_{i\to j}} f_{(i'j')}(Y_{i'}, Y_{j'})$$
Then the messages can be computed recursively as follows:
$$M_{i\to j}(Y_j) = \sum_{Y_i} f_{(ij)}(Y_i, Y_j) \prod_{k \in \text{ne}(i)\setminus j} M_{k\to i}(Y_i)$$
and:
$$p(Y_i) \propto \prod_{k \in \text{ne}(i)} M_{k\to i}(Y_i)$$
$$p(Y_i, Y_j) \propto f_{(ij)}(Y_i, Y_j) \prod_{k \in \text{ne}(i)\setminus j} M_{k\to i}(Y_i) \prod_{l \in \text{ne}(j)\setminus i} M_{l\to j}(Y_j)$$

Recap on Graphical Models

A graphical model represents a family of distributions: the distributions that factorize according to the graph,
$$\Big\{\, p(X) : p(X) = \prod_i p(X_i \mid X_{\text{pa}(i)}) \,\Big\},$$
or equivalently those satisfying the conditional independencies implied by the graph, $\{X_i \perp\!\!\!\perp Y_i \mid \mathcal{V}_i\}$, i.e. $p(X_i, Y_i \mid \mathcal{V}_i) = p(X_i \mid \mathcal{V}_i)\,p(Y_i \mid \mathcal{V}_i)$.

Graph Transformations

Our general approach to inference in arbitrary graphical models is to transform these graphical models into ones belonging to an easy-to-handle class (specifically, junction or join trees). Suppose we are interested in inference in a distribution $p(Y)$ representable in a graphical model $G$.
• In transforming $G$ into an easy-to-handle $G'$, we need to ensure that $p(Y)$ is representable in $G'$ too.
• This can be ensured by making sure that every step of the graph transformation only removes conditional independencies, never adds them.
• This guarantees that the family of distributions can only grow at each step, so $p(Y)$ will remain in the family of distributions represented by $G'$.
• Thus inference algorithms working on $G'$ will work for $p(Y)$ too.

The Junction Tree Algorithm

[Figure: the stages of the junction tree algorithm on a five-node example (A, B, C, D, E): directed acyclic graph, factor graph, undirected graph, chordal (triangulated) undirected graph, junction tree, message passing.]

Directed Acyclic Graphs to Factor Graphs

[Figure: the DAG over A, B, C, D, E and the corresponding factor graph.]

The factors are simply the conditional distributions of the DAG:
$$p(Y) = \prod_i p(Y_i \mid Y_{\text{pa}(i)}) = \prod_i f_i(Y_{C_i})$$
where $C_i = \{i\} \cup \text{pa}(i)$ and $f_i(Y_{C_i}) = p(Y_i \mid Y_{\text{pa}(i)})$. The marginal distribution $p(Y_r)$ on a root $r$ is absorbed into an adjacent factor.

Entering Evidence in Factor Graphs

[Figure: the factor graph with additional single-node evidence factors attached to C and D.]

Often we are interested in inferring posterior distributions given observed evidence (e.g. $D = \text{wet}$ and $C = \text{rain}$). This can be achieved by adding factors with just one adjacent node:
$$f_D(D) = \begin{cases} 1 & \text{if } D = \text{wet}, \\ 0 & \text{otherwise}, \end{cases} \qquad f_C(C) = \begin{cases} 1 & \text{if } C = \text{rain}, \\ 0 & \text{otherwise}. \end{cases}$$

Factor Graphs to Undirected Graphs

[Figure: the factor graph and the corresponding undirected graph.]

We just need to make sure that every factor is contained in some maximal clique of the undirected graph:
$$p(Y) = \frac{1}{Z} \prod_i f_i(Y_{C_i})$$
We can ensure this simply by converting each factor into a clique, and absorbing $f_i(Y_{C_i})$ into the factor of some maximal clique containing it.

The transformation DAG $\Rightarrow$ undirected graph is called moralization: we simply "marry" all parents of each node by adding edges connecting them, then drop the arrows on all edges.
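As a concrete illustration of moralization, here is a small sketch (not from the slides) that takes a DAG described by a parent list, marries the parents of each node, and drops the edge directions. The five-node DAG at the end is only an illustrative guess; the exact edge set of the A-E example in the figures is not recoverable here.

```python
import itertools

def moralize(parents):
    """Moralization: marry the parents of each node, then drop edge directions.

    parents : dict mapping each node of a DAG to a list of its parents.
    Returns the undirected moral graph as a dict of neighbour sets.
    """
    nodes = set(parents) | {p for pa in parents.values() for p in pa}
    moral = {v: set() for v in nodes}

    def connect(u, v):
        moral[u].add(v)
        moral[v].add(u)

    for child, pa in parents.items():
        for p in pa:                                  # drop the arrows
            connect(child, p)
        for p, q in itertools.combinations(pa, 2):    # marry the parents
            connect(p, q)
    return moral

# Illustrative (hypothetical) five-node DAG and its moral graph.
dag = {"A": [], "B": [], "E": ["A", "B"], "C": ["B", "E"], "D": ["C", "E"]}
moral_graph = moralize(dag)   # e.g. A and B become neighbours ("married")
```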
Triangulation of Undirected Graphs

[Figure: the undirected graph over A, B, C, D, E before and after triangulation.]

Message passing: messages contain information from other parts of the graph, and this information is propagated around the graph during inference. If loops (cycles) are present in the graph, message passing can lead to overconfidence due to double counting of information, and to oscillatory (non-convergent) behaviour. (Message passing on a loopy graph is called loopy belief propagation, and we will see in the second half of the course that it is an important class of approximate inference algorithms.) To prevent this overconfident and oscillatory behaviour, we need to make sure that the different channels of communication coordinate with each other to prevent double counting.

Triangulation: add edges to the graph so that every loop of length $\geq 4$ has at least one chord. Note the recursive nature of this requirement: adding edges often creates new loops, and we need to make sure the new loops of length $\geq 4$ have chords too.

Remember that adding edges only ever removes conditional independencies and so enlarges the family of distributions. There are many ways to add chords; in general, finding the best triangulation is NP-complete.
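A standard practical way to triangulate is by node elimination: fix an elimination ordering, and as each node is eliminated, connect all of its remaining neighbours (the "fill-in" edges); the original edges together with the fill-ins form a chordal graph. The sketch below is an illustrative implementation under that assumption. The ordering is left to the caller, since, as noted above, finding the best triangulation is NP-complete (greedy heuristics such as min-fill are commonly used).

```python
import itertools

def triangulate(adj, order):
    """Triangulate an undirected graph by node elimination.

    adj   : dict mapping each node to the set of its neighbours.
    order : elimination ordering, a list containing every node once.
    """
    work = {v: set(nbrs) for v, nbrs in adj.items()}   # working copy, consumed
    tri = {v: set(nbrs) for v, nbrs in adj.items()}    # triangulated result

    for v in order:
        nbrs = list(work[v])
        # Fill in: every pair of v's remaining neighbours becomes connected.
        for a, b in itertools.combinations(nbrs, 2):
            work[a].add(b); work[b].add(a)
            tri[a].add(b); tri[b].add(a)
        # Eliminate v from the working graph.
        for a in nbrs:
            work[a].discard(v)
        del work[v]
    return tri

# Example: a chordless 4-cycle A-B-C-D picks up the chord B-D (or A-C,
# depending on the elimination ordering).
cycle = {"A": {"B", "D"}, "B": {"A", "C"}, "C": {"B", "D"}, "D": {"C", "A"}}
chordal = triangulate(cycle, order=["A", "B", "C", "D"])
```

The size of the largest clique in the resulting chordal graph governs the cost of the subsequent junction tree message passing, which is why the choice of elimination ordering matters.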