Lecture 8 and 9: Message passing on Factor Graphs

Machine Learning 4F13, Spring 2015

Carl Edward Rasmussen and Zoubin Ghahramani

CUED

http://mlg.eng.cam.ac.uk/teaching/4f13/

Key Concepts

• Factor graphs are a class of probabilistic graphical model
• A factor graph represents the product structure of a function, and contains factor nodes and variable nodes
• We can compute marginals and conditionals efficiently by passing messages on the factor graph; this is called the sum-product algorithm (a.k.a. belief propagation or factor-graph propagation)
• We can apply this to the TrueSkill graph
• But certain messages need to be approximated
• One approximation method, based on moment matching, is called Expectation Propagation (EP)

Factor Graphs

Factor graphs are a type of probabilistic graphical model (others are directed graphs, a.k.a. Bayesian networks, and undirected graphs, a.k.a. Markov networks). Factor graphs allow us to represent the product structure of a function. Example: consider the factorising joint distribution:

$$p(v, w, x, y, z) = f_1(v, w)\, f_2(w, x)\, f_3(x, y)\, f_4(x, z)$$

A factor graph is a bipartite graph with two types of nodes:

• Factor node (drawn as a filled square): one per factor
• Variable node (drawn as a circle): one per variable
• Edges represent the dependency of factors on variables.

Factor Graphs

$$p(v, w, x, y, z) = f_1(v, w)\, f_2(w, x)\, f_3(x, y)\, f_4(x, z)$$

• What are the marginal distributions of the individual variables? • What is p(w)? • How do we compute conditional distributions, e.g. p(w|y)?

For now, we will focus on tree-structured factor graphs.

Factor trees: separation (1)

$$p(w) = \sum_v \sum_x \sum_y \sum_z f_1(v, w)\, f_2(w, x)\, f_3(x, y)\, f_4(x, z)$$

• If $v$, $x$, $y$ and $z$ take $K$ values each, we have $\approx 3K^4$ products and $\approx K^4$ sums, for each value of $w$, i.e. total $O(K^5)$.
• Multiplication is distributive: $ca + cb = c(a + b)$. The right-hand side is more efficient!

Factor trees: separation (2)

$$p(w) = \sum_v \sum_x \sum_y \sum_z f_1(v, w)\, f_2(w, x)\, f_3(x, y)\, f_4(x, z)
= \Big[\sum_v f_1(v, w)\Big] \cdot \Big[\sum_x \sum_y \sum_z f_2(w, x)\, f_3(x, y)\, f_4(x, z)\Big]$$

• In a tree, each node separates the graph into two parts.
• Grouping terms, we go from sums of products to products of sums.
• The complexity is now $O(K^4)$.

Factor trees: separation (3)

$$p(w) = \underbrace{\Big[\sum_v f_1(v, w)\Big]}_{m_{f_1 \to w}(w)} \cdot \underbrace{\Big[\sum_x \sum_y \sum_z f_2(w, x)\, f_3(x, y)\, f_4(x, z)\Big]}_{m_{f_2 \to w}(w)}$$

• Sums of products become products of sums of all messages from neighbouring factors to the variable.

Messages: from factors to variables (1)

$$m_{f_2 \to w}(w) = \sum_x \sum_y \sum_z f_2(w, x)\, f_3(x, y)\, f_4(x, z)$$

Messages: from factors to variables (2)

$$m_{f_2 \to w}(w) = \sum_x \sum_y \sum_z f_2(w, x)\, f_3(x, y)\, f_4(x, z)
= \sum_x f_2(w, x) \cdot \underbrace{\Big[\sum_y \sum_z f_3(x, y)\, f_4(x, z)\Big]}_{m_{x \to f_2}(x)}$$

• Factors only need to sum out all their local variables.

Messages: from variables to factors (1)

$$m_{x \to f_2}(x) = \sum_y \sum_z f_3(x, y)\, f_4(x, z)$$

Messages: from variables to factors (2)

$$m_{x \to f_2}(x) = \sum_y \sum_z f_3(x, y)\, f_4(x, z)
= \underbrace{\Big[\sum_y f_3(x, y)\Big]}_{m_{f_3 \to x}(x)} \cdot \underbrace{\Big[\sum_z f_4(x, z)\Big]}_{m_{f_4 \to x}(x)}$$

• Variables pass on the product of all incoming messages.

Factor graph marginalisation: summary

$$p(w) = \sum_v \sum_x \sum_y \sum_z f_1(v, w)\, f_2(w, x)\, f_3(x, y)\, f_4(x, z)
= \underbrace{\Big[\sum_v f_1(v, w)\Big]}_{m_{f_1 \to w}(w)} \cdot
\underbrace{\Big[\sum_x f_2(w, x) \cdot
\underbrace{\Big[\, \underbrace{\Big[\sum_y f_3(x, y)\Big]}_{m_{f_3 \to x}(x)} \cdot
\underbrace{\Big[\sum_z f_4(x, z)\Big]}_{m_{f_4 \to x}(x)} \Big]}_{m_{x \to f_2}(x)} \Big]}_{m_{f_2 \to w}(w)}$$

• The complexity is reduced from $O(K^5)$ (naïve implementation) to $O(K^2)$.
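The reduction from sums of products to products of sums is easy to verify numerically. Below is a minimal sketch (assuming Python with NumPy; the random non-negative factor tables are invented purely for illustration) that computes p(w) for the example graph both by brute-force summation and by the message-passing scheme above, and checks that the two agree.

```python
import numpy as np

K = 5                                    # each variable takes K values
rng = np.random.default_rng(0)

# Random non-negative factor tables (illustrative only):
# f1(v, w), f2(w, x), f3(x, y), f4(x, z)
f1 = rng.random((K, K))
f2 = rng.random((K, K))
f3 = rng.random((K, K))
f4 = rng.random((K, K))

# Naive marginal: build the full product and sum over v, x, y, z -- O(K^5)
joint = np.einsum('vw,wx,xy,xz->vwxyz', f1, f2, f3, f4)
p_w_naive = joint.sum(axis=(0, 2, 3, 4))

# Message passing: every message is a length-K vector -- O(K^2)
m_f1_to_w = f1.sum(axis=0)               # m_{f1->w}(w) = sum_v f1(v, w)
m_f3_to_x = f3.sum(axis=1)               # m_{f3->x}(x) = sum_y f3(x, y)
m_f4_to_x = f4.sum(axis=1)               # m_{f4->x}(x) = sum_z f4(x, z)
m_x_to_f2 = m_f3_to_x * m_f4_to_x        # variable x multiplies incoming messages
m_f2_to_w = f2 @ m_x_to_f2               # m_{f2->w}(w) = sum_x f2(w, x) m_{x->f2}(x)
p_w_messages = m_f1_to_w * m_f2_to_w

assert np.allclose(p_w_naive, p_w_messages)
print(p_w_messages / p_w_messages.sum())  # normalised marginal p(w)
```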

The sum-product algorithm

Three update equations:

• Marginals are the product of all incoming messages from neighbouring factors:
$$p(t) = \prod_{f \in F_t} m_{f \to t}(t)$$
where $F_t$ is the set of factors neighbouring variable $t$.

• Messages from factors sum out all variables except the receiving one:
$$m_{f \to t_1}(t_1) = \sum_{t_2} \sum_{t_3} \cdots \sum_{t_n} f(t_1, t_2, \ldots, t_n) \prod_{i \neq 1} m_{t_i \to f}(t_i)$$

• Messages from variables are the product of all incoming messages except the message from the receiving factor:
$$m_{t \to f}(t) = \prod_{f_j \in F_t \setminus \{f\}} m_{f_j \to t}(t)$$

Messages are results of partial computations. Computations are localised.
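For discrete variables, the three update equations map directly onto array operations. The helpers below are a minimal sketch of my own (assuming NumPy; the function names and the toy three-variable factor are invented for illustration), not code from the lecture.

```python
import numpy as np

def factor_to_var(factor, incoming, target_axis):
    """m_{f->t}(t): multiply in the messages from all other neighbouring
    variables, then sum them out, leaving only the target variable's axis."""
    msg = factor.copy()
    for axis, m in incoming.items():         # incoming: {axis: message vector}
        if axis != target_axis:
            shape = [1] * factor.ndim
            shape[axis] = -1
            msg = msg * m.reshape(shape)      # broadcast m_{t_i -> f}(t_i)
    other_axes = tuple(a for a in range(factor.ndim) if a != target_axis)
    return msg.sum(axis=other_axes)

def var_to_factor(incoming_to_var, exclude):
    """m_{t->f}(t): product of all factor-to-variable messages except the
    one coming from the receiving factor."""
    msgs = [m for name, m in incoming_to_var.items() if name != exclude]
    return np.prod(msgs, axis=0)

def marginal(incoming_to_var):
    """p(t) is proportional to the product of all incoming factor messages."""
    p = np.prod(list(incoming_to_var.values()), axis=0)
    return p / p.sum()

# Example: a factor f(t1, t2, t3) over binary variables, with unit incoming messages
f = np.random.default_rng(1).random((2, 2, 2))
m_f_to_t1 = factor_to_var(f, {1: np.ones(2), 2: np.ones(2)}, target_axis=0)
```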

The full TrueSkill graph

Prior factors: $f_i(w_i) = N(w_i;\ \mu_0, \sigma_0^2)$

"Game" factors: $h_g(w_{I_g}, w_{J_g}, t_g) = N(t_g;\ w_{I_g} - w_{J_g}, 1)$

($I_g$ and $J_g$ are the players in game $g$)

Outcome factors: $k_g(t_g, y_g) = \delta\big(y_g - \mathrm{sign}(t_g)\big)$

We are interested in the marginal distributions of the skills $w_i$.

• What shape do these distributions have? • We need to make some approximations. • We will also pretend the structure is a tree (ignore loops).
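To make the three factor types concrete, the sketch below (my own illustration, assuming NumPy; the prior values µ0 = 0, σ0 = 1 and the match-ups are made up) samples skills, performances and game outcomes from the generative model. Inference runs in the opposite direction: from the observed outcomes y_g back to the skills w_i.

```python
import numpy as np

rng = np.random.default_rng(0)

n_players = 4
mu0, sigma0 = 0.0, 1.0                      # illustrative prior parameters

# Prior factors f_i(w_i) = N(w_i; mu0, sigma0^2): one skill per player
w = rng.normal(mu0, sigma0, size=n_players)

# Games: (I_g, J_g) pairs of players; a few made-up match-ups
games = [(0, 1), (1, 2), (2, 3), (0, 3)]

for I, J in games:
    # Game factor h_g: performance difference t_g ~ N(w_I - w_J, 1)
    t = rng.normal(w[I] - w[J], 1.0)
    # Outcome factor k_g: y_g = sign(t_g), so y = +1 means player I wins
    y = int(np.sign(t))
    print(f"game ({I}, {J}): t = {t:+.2f}, outcome y = {y:+d}")
```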

Expectation Propagation in the full TrueSkill graph

Iterate

(1) Update skill marginals.

(2) Compute skill to game messages.

(3) Compute game to performance messages.

(4) Approximate performance marginals.

(5) Compute performance to game messages.

(6) Compute game to skill messages.

Message passing for TrueSkill

$$m^{\tau=0}_{h_g \to w_{I_g}}(w_{I_g}) = 1, \qquad m^{\tau=0}_{h_g \to w_{J_g}}(w_{J_g}) = 1, \qquad \forall\, g,$$

$$q^\tau(w_i) = f_i(w_i) \prod_{g=1}^{N} m^\tau_{h_g \to w_i}(w_i) \sim N(\mu_i, \sigma_i^2),$$

$$m^\tau_{w_{I_g} \to h_g}(w_{I_g}) = \frac{q^\tau(w_{I_g})}{m^\tau_{h_g \to w_{I_g}}(w_{I_g})}, \qquad
m^\tau_{w_{J_g} \to h_g}(w_{J_g}) = \frac{q^\tau(w_{J_g})}{m^\tau_{h_g \to w_{J_g}}(w_{J_g})},$$

$$m^\tau_{h_g \to t_g}(t_g) = \iint h_g(t_g, w_{I_g}, w_{J_g})\, m^\tau_{w_{I_g} \to h_g}(w_{I_g})\, m^\tau_{w_{J_g} \to h_g}(w_{J_g})\, \mathrm{d}w_{I_g}\, \mathrm{d}w_{J_g},$$

$$q^{\tau+1}(t_g) \overset{\mathrm{Approx}}{=} m^\tau_{h_g \to t_g}(t_g)\, m_{k_g \to t_g}(t_g), \qquad
m^{\tau+1}_{t_g \to h_g}(t_g) = \frac{q^{\tau+1}(t_g)}{m^\tau_{h_g \to t_g}(t_g)},$$

$$m^{\tau+1}_{h_g \to w_{I_g}}(w_{I_g}) = \iint h_g(t_g, w_{I_g}, w_{J_g})\, m^{\tau+1}_{t_g \to h_g}(t_g)\, m^\tau_{w_{J_g} \to h_g}(w_{J_g})\, \mathrm{d}t_g\, \mathrm{d}w_{J_g},$$

$$m^{\tau+1}_{h_g \to w_{J_g}}(w_{J_g}) = \iint h_g(t_g, w_{I_g}, w_{J_g})\, m^{\tau+1}_{t_g \to h_g}(t_g)\, m^\tau_{w_{I_g} \to h_g}(w_{I_g})\, \mathrm{d}t_g\, \mathrm{d}w_{I_g}.$$

In a little more detail

At iteration $\tau$ messages $m$ and marginals $q$ are Gaussian, with means $\mu$, standard deviations $\sigma$, variances $v = \sigma^2$, precisions $r = v^{-1}$ and natural means $\lambda = r\mu$.

Step 0 Initialise incoming skill messages:
$$m^{\tau=0}_{h_g \to w_i}(w_i): \qquad r^{\tau=0}_{h_g \to w_i} = 0, \qquad \mu^{\tau=0}_{h_g \to w_i} = 0$$

Step 1 Compute marginal skills:
$$q^\tau(w_i): \qquad r_i^\tau = r_0 + \sum_g r^\tau_{h_g \to w_i}, \qquad \lambda_i^\tau = \lambda_0 + \sum_g \lambda^\tau_{h_g \to w_i}$$

Step 2 Compute skill to game messages:
$$m^\tau_{w_i \to h_g}(w_i): \qquad r^\tau_{w_i \to h_g} = r_i^\tau - r^\tau_{h_g \to w_i}, \qquad \lambda^\tau_{w_i \to h_g} = \lambda_i^\tau - \lambda^\tau_{h_g \to w_i}$$

Step 3 Game to performance messages:

$$m^\tau_{h_g \to t_g}(t_g): \qquad v^\tau_{h_g \to t_g} = 1 + v^\tau_{w_{I_g} \to h_g} + v^\tau_{w_{J_g} \to h_g}, \qquad \mu^\tau_{h_g \to t_g} = \mu^\tau_{w_{I_g} \to h_g} - \mu^\tau_{w_{J_g} \to h_g}$$

Step 4 Compute marginal performances:

$$p(t_g) \propto N(t_g;\ \mu^\tau_{h_g \to t_g}, v^\tau_{h_g \to t_g})\, \mathbb{I}\big(y - \mathrm{sign}(t_g)\big) \simeq N(\tilde\mu_g^{\tau+1}, \tilde v_g^{\tau+1}) = q^{\tau+1}(t_g)$$

We find the parameters of q by moment matching

$$q^{\tau+1}(t_g): \qquad
\tilde v_g^{\tau+1} = v^\tau_{h_g \to t_g}\Big(1 - \Lambda\Big(\frac{\mu^\tau_{h_g \to t_g}}{\sigma^\tau_{h_g \to t_g}}\Big)\Big), \qquad
\tilde\mu_g^{\tau+1} = \mu^\tau_{h_g \to t_g} + \sigma^\tau_{h_g \to t_g}\, \Psi\Big(\frac{\mu^\tau_{h_g \to t_g}}{\sigma^\tau_{h_g \to t_g}}\Big)$$
where we have defined $\Psi(x) = N(x)/\Phi(x)$ and $\Lambda(x) = \Psi(x)\big(\Psi(x) + x\big)$.

Step 5 Performance to game message:

$$m^{\tau+1}_{t_g \to h_g}(t_g): \qquad r^{\tau+1}_{t_g \to h_g} = \tilde r_g^{\tau+1} - r^\tau_{h_g \to t_g}, \qquad \lambda^{\tau+1}_{t_g \to h_g} = \tilde\lambda_g^{\tau+1} - \lambda^\tau_{h_g \to t_g}$$

Step 6 Game to skill message: For player 1 (the winner):

$$m^{\tau+1}_{h_g \to w_{I_g}}(w_{I_g}): \qquad v^{\tau+1}_{h_g \to w_{I_g}} = 1 + v^{\tau+1}_{t_g \to h_g} + v^\tau_{w_{J_g} \to h_g}, \qquad \mu^{\tau+1}_{h_g \to w_{I_g}} = \mu^\tau_{w_{J_g} \to h_g} + \mu^{\tau+1}_{t_g \to h_g}$$
and for player 2 (the loser):

$$m^{\tau+1}_{h_g \to w_{J_g}}(w_{J_g}): \qquad v^{\tau+1}_{h_g \to w_{J_g}} = 1 + v^{\tau+1}_{t_g \to h_g} + v^\tau_{w_{I_g} \to h_g}, \qquad \mu^{\tau+1}_{h_g \to w_{J_g}} = \mu^\tau_{w_{I_g} \to h_g} - \mu^{\tau+1}_{t_g \to h_g}$$

Go back to Step 1 with τ := τ + 1 (or stop).
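Putting Steps 0 to 6 together, here is a compact sketch of the EP loop (my own illustrative implementation, assuming NumPy and SciPy; the helper names, the made-up example games and the prior values µ0 = 0, σ0 = 1 are not from the slides). As on the slides, player I_g is taken to be the winner of game g, and messages are stored in natural-parameter form (r, λ) with a simple parallel update schedule.

```python
import numpy as np
from scipy.stats import norm

def psi(x):                        # Psi(x) = N(x; 0, 1) / Phi(x)
    return norm.pdf(x) / norm.cdf(x)

def lam(x):                        # Lambda(x) = Psi(x) (Psi(x) + x)
    return psi(x) * (psi(x) + x)

def ep_sweep(n_players, games, mu0, sigma0, n_iter=20):
    """games: list of (I_g, J_g) pairs with I_g the winner of game g.
    Returns approximate posterior means and variances of the skills."""
    r0, lam0 = 1.0 / sigma0**2, mu0 / sigma0**2
    G = len(games)
    # Step 0: game-to-skill messages (r, lambda), initialised to zero
    r_gs = np.zeros((G, 2))        # column 0: to winner I_g, column 1: to loser J_g
    lam_gs = np.zeros((G, 2))
    for _ in range(n_iter):
        # Step 1: skill marginals q(w_i) in natural parameters
        r_i = np.full(n_players, r0)
        lam_i = np.full(n_players, lam0)
        for g, (I, J) in enumerate(games):
            r_i[[I, J]] += r_gs[g]
            lam_i[[I, J]] += lam_gs[g]
        for g, (I, J) in enumerate(games):
            # Step 2: skill-to-game messages (divide marginal by game message)
            r_sg = np.array([r_i[I], r_i[J]]) - r_gs[g]
            lam_sg = np.array([lam_i[I], lam_i[J]]) - lam_gs[g]
            v_sg, mu_sg = 1.0 / r_sg, lam_sg / r_sg
            # Step 3: game-to-performance message
            v_gt = 1.0 + v_sg[0] + v_sg[1]
            mu_gt = mu_sg[0] - mu_sg[1]
            # Step 4: approximate the performance marginal by moment matching
            z = mu_gt / np.sqrt(v_gt)
            v_t = v_gt * (1.0 - lam(z))
            mu_t = mu_gt + np.sqrt(v_gt) * psi(z)
            # Step 5: performance-to-game message
            r_tg = 1.0 / v_t - 1.0 / v_gt
            lam_tg = mu_t / v_t - mu_gt / v_gt
            v_tg, mu_tg = 1.0 / r_tg, lam_tg / r_tg
            # Step 6: game-to-skill messages (winner I, loser J)
            v_gI, mu_gI = 1.0 + v_tg + v_sg[1], mu_sg[1] + mu_tg
            v_gJ, mu_gJ = 1.0 + v_tg + v_sg[0], mu_sg[0] - mu_tg
            r_gs[g] = [1.0 / v_gI, 1.0 / v_gJ]
            lam_gs[g] = [mu_gI / v_gI, mu_gJ / v_gJ]
    # Final Step 1: recompute marginals with the latest messages
    r_i = np.full(n_players, r0)
    lam_i = np.full(n_players, lam0)
    for g, (I, J) in enumerate(games):
        r_i[[I, J]] += r_gs[g]
        lam_i[[I, J]] += lam_gs[g]
    return lam_i / r_i, 1.0 / r_i

# Example: player 0 beats 1, 1 beats 2, 0 beats 2 (made-up results)
means, variances = ep_sweep(3, [(0, 1), (1, 2), (0, 2)], mu0=0.0, sigma0=1.0)
print(means, variances)
```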

Moments of a truncated Gaussian density (1)

Consider the truncated Gaussian density function
$$p(t) = \frac{1}{Z_t}\, \delta\big(y - \mathrm{sign}(t)\big)\, N(t;\ \mu, \sigma^2)$$
where $y \in \{-1, 1\}$ and $\delta(x) = 1$ iff $x = 0$.

[Figure: probability densities of the original Gaussian, the renormalised truncated Gaussian, and the moment-matched Gaussian approximation.]

We want to approximate p(t) by a Gaussian density function q(t) with mean and variance equal to the first and second central moments of p(t). We need:

• First moment: $\mathbb{E}[t] = \langle t \rangle_{p(t)}$
• Second central moment: $\mathbb{V}[t] = \langle t^2 \rangle_{p(t)} - \langle t \rangle^2_{p(t)}$

Moments of a truncated Gaussian density (2)

We have seen that the normalisation constant is $Z_t = \Phi\big(\frac{y\mu}{\sigma}\big)$.

First moment. We take the derivative of $Z_t$ w.r.t. $\mu$:
$$\frac{\partial Z_t}{\partial \mu} = \frac{\partial}{\partial \mu} \int_0^{+\infty} N(t;\ y\mu, \sigma^2)\, \mathrm{d}t = \int_0^{+\infty} \frac{\partial}{\partial \mu} N(t;\ y\mu, \sigma^2)\, \mathrm{d}t$$
$$= \int_0^{+\infty} y\sigma^{-2}(t - y\mu)\, N(t;\ y\mu, \sigma^2)\, \mathrm{d}t = y Z_t \sigma^{-2} \int_{-\infty}^{+\infty} (t - y\mu)\, p(t)\, \mathrm{d}t$$
$$= y Z_t \sigma^{-2}\, \langle t - y\mu \rangle_{p(t)} = y Z_t \sigma^{-2}\, \langle t \rangle_{p(t)} - \mu Z_t \sigma^{-2}$$
where $\langle t \rangle_{p(t)}$ is the expectation of $t$ under $p(t)$. We can also write:
$$\frac{\partial Z_t}{\partial \mu} = \frac{\partial}{\partial \mu}\, \Phi\Big(\frac{y\mu}{\sigma}\Big) = y\, N(y\mu;\ 0, \sigma^2)$$

Combining both expressions for $\frac{\partial Z_t}{\partial \mu}$ we obtain
$$\langle t \rangle_{p(t)} = y\mu + \sigma^2\, \frac{N(y\mu;\ 0, \sigma^2)}{\Phi(\frac{y\mu}{\sigma})} = y\mu + \sigma\, \frac{N(\frac{y\mu}{\sigma};\ 0, 1)}{\Phi(\frac{y\mu}{\sigma})} = y\mu + \sigma\, \Psi\Big(\frac{y\mu}{\sigma}\Big)$$
where we use $N(y\mu;\ 0, \sigma^2) = \sigma^{-1} N(\frac{y\mu}{\sigma};\ 0, 1)$ and define $\Psi(z) = N(z;\ 0, 1)/\Phi(z)$.

Moments of a truncated Gaussian density (3)

Second moment. We take the second derivative of $Z_t$ w.r.t. $\mu$:
$$\frac{\partial^2 Z_t}{\partial \mu^2} = \frac{\partial}{\partial \mu} \int_0^{+\infty} y\sigma^{-2}(t - y\mu)\, N(t;\ y\mu, \sigma^2)\, \mathrm{d}t = \Phi\Big(\frac{y\mu}{\sigma}\Big)\, \Big\langle -\sigma^{-2} + \sigma^{-4}(t - y\mu)^2 \Big\rangle_{p(t)}$$

We can also write
$$\frac{\partial^2 Z_t}{\partial \mu^2} = \frac{\partial}{\partial \mu}\, y\, N(y\mu;\ 0, \sigma^2) = -\sigma^{-2}\, y\mu\, N(y\mu;\ 0, \sigma^2)$$

Combining both we obtain
$$\mathbb{V}[t] = \sigma^2\Big(1 - \Lambda\Big(\frac{y\mu}{\sigma}\Big)\Big)$$
where we define $\Lambda(z) = \Psi(z)\big(\Psi(z) + z\big)$.
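The closed-form moments are easy to sanity-check numerically. A minimal sketch (assuming NumPy and SciPy; the particular values of y, µ and σ are arbitrary) compares the formulas above with moments obtained by numerical integration of the truncated density.

```python
import numpy as np
from scipy.stats import norm
from scipy.integrate import quad

def psi(z):                 # Psi(z) = N(z; 0, 1) / Phi(z)
    return norm.pdf(z) / norm.cdf(z)

def lam(z):                 # Lambda(z) = Psi(z) (Psi(z) + z)
    return psi(z) * (psi(z) + z)

def truncated_moments(y, mu, sigma):
    """Closed-form mean and variance of p(t) ∝ δ(y − sign(t)) N(t; mu, sigma²)."""
    z = y * mu / sigma
    mean = y * mu + sigma * psi(z)
    var = sigma**2 * (1.0 - lam(z))
    return mean, var

y, mu, sigma = 1, 0.3, 1.5                   # arbitrary example values
mean, var = truncated_moments(y, mu, sigma)

# Numerical check: integrate over the allowed half-line (t > 0 for y = +1)
Z = norm.cdf(y * mu / sigma)
p = lambda t: norm.pdf(t, mu, sigma) / Z
m1, _ = quad(lambda t: t * p(t), 0, np.inf)
m2, _ = quad(lambda t: t**2 * p(t), 0, np.inf)
print(mean, m1)             # first moments should agree
print(var, m2 - m1**2)      # variances should agree
```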
