15 | Variational inference
15.1 Foundations
Variational inference is a statistical inference framework for probabilistic models that comprise unobservable random variables. Its general starting point is a joint probability density function over observable random variables y and unobservable random variables ϑ,
\[
p(y, \vartheta) = p(\vartheta)\, p(y \mid \vartheta), \tag{15.1}
\]
where p(ϑ) is usually referred to as the prior density and p(y|ϑ) as the likelihood. Given an observed value of y, the first aim of variational inference is to determine the conditional density of ϑ given the observed value of y, referred to as the posterior density. The second aim of variational inference is to evaluate the marginal density of the observed data, or, equivalently, its logarithm
\[
\ln p(y) = \ln \int p(y, \vartheta)\, d\vartheta. \tag{15.2}
\]
Eq. (15.2) is commonly referred to as the log marginal likelihood or log model evidence. The log model evidence allows for comparing different models in their plausibility to explain observed data. In the variational inference framework it is not the log model evidence itself which is evaluated, but rather a lower bound approximation to it. This is due to the fact that, if a model comprises many unobservable variables ϑ, the integration on the right-hand side of (15.2) can become analytically burdensome or even intractable. To nevertheless achieve its two aims, variational inference in effect replaces an integration problem with an optimization problem. To this end, variational inference exploits a set of information theoretic quantities as introduced in Chapter 11 and below. Specifically, the following log model evidence decomposition forms the core of the variational inference approach (Figure 15.1):
\[
\ln p(y) = F(q(\vartheta)) + \mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y)). \tag{15.3}
\]
In eq. (15.3), q(ϑ) denotes an arbitrary probability density over the unobservable variables which is used as an approximation of the posterior density p(ϑ|y). In the following, q(ϑ) is referred to as the variational density. In words, (15.3) states that for an arbitrary variational density q(ϑ), the log model evidence is the sum of two information theoretic quantities: the so-called variational free energy, defined as
\[
F(q(\vartheta)) := \int q(\vartheta) \ln \frac{p(y, \vartheta)}{q(\vartheta)}\, d\vartheta, \tag{15.4}
\]
and the KL divergence between the variational density q(ϑ) and the true posterior density p(ϑ|y),
\[
\mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y)) = \int q(\vartheta) \ln \frac{q(\vartheta)}{p(\vartheta \mid y)}\, d\vartheta. \tag{15.5}
\]
Based on these definitions, it is straightforward to show the validity of the log model evidence decomposition:
Figure 15.1. Visualization of the log model evidence decomposition that lies at the heart of the variational inference approach. The upper vertical bar represents the log model evidence, which is a function of the probabilistic model p(y, ϑ) and is constant for any observation of y. As shown in the main text, the log model evidence can readily be rewritten into the sum of the variational free energy term F(q(ϑ)) and a KL divergence term KL(q(ϑ)||p(ϑ|y)), if one introduces an arbitrary variational density over the unobservable variables ϑ. Maximizing the variational free energy hence minimizes the KL divergence between the variational density q(ϑ) and the true posterior density p(ϑ|y), and renders the variational free energy a tighter approximation to the log model evidence ln p(y).
Proof of (15.3)
By definition, we have
\[
\begin{aligned}
F(q(\vartheta)) &= \int q(\vartheta) \ln \frac{p(y, \vartheta)}{q(\vartheta)}\, d\vartheta \\
&= \int q(\vartheta) \ln \frac{p(y)\, p(\vartheta \mid y)}{q(\vartheta)}\, d\vartheta \\
&= \int q(\vartheta) \ln p(y)\, d\vartheta + \int q(\vartheta) \ln \frac{p(\vartheta \mid y)}{q(\vartheta)}\, d\vartheta \\
&= \ln p(y) \int q(\vartheta)\, d\vartheta + \int q(\vartheta) \ln \frac{p(\vartheta \mid y)}{q(\vartheta)}\, d\vartheta \\
&= \ln p(y) - \int q(\vartheta) \ln \frac{q(\vartheta)}{p(\vartheta \mid y)}\, d\vartheta \\
&= \ln p(y) - \mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y)),
\end{aligned} \tag{15.6}
\]

from which eq. (15.3) follows immediately. ∎

The log model evidence decomposition (15.3) can be used to achieve the aims of variational inference as follows: first, the non-negativity of the KL divergence implies that the variational free energy F(q(ϑ)) is always smaller than or equal to the log model evidence, i.e.,
F (q(ϑ)) ≤ ln p(y). (15.7)
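The decomposition (15.3) and the bound (15.7) can be checked numerically on a model for which everything is available in closed form. The following sketch is a minimal illustration under assumed, arbitrarily chosen parameter values for a conjugate univariate Gaussian model; it evaluates F(q) and the KL term by quadrature for a deliberately non-optimal variational density:

```python
import numpy as np
from scipy.integrate import trapezoid
from scipy.stats import norm

# Conjugate Gaussian model (all parameter values are illustrative choices):
# prior theta ~ N(m0, s02), likelihood y | theta ~ N(theta, sig2)
m0, s02, sig2, y = 0.0, 2.0, 1.0, 1.5

# Exact log model evidence: marginally, y ~ N(m0, s02 + sig2)
log_evidence = norm.logpdf(y, m0, np.sqrt(s02 + sig2))

# Exact posterior parameters of the conjugate model
sp2 = 1.0 / (1.0 / s02 + 1.0 / sig2)
mp = sp2 * (m0 / s02 + y / sig2)

# A deliberately non-optimal variational density q(theta) = N(mq, sq2)
mq, sq2 = 0.5, 0.8

theta = np.linspace(-15.0, 15.0, 30001)
q = norm.pdf(theta, mq, np.sqrt(sq2))
log_q = norm.logpdf(theta, mq, np.sqrt(sq2))
log_joint = norm.logpdf(theta, m0, np.sqrt(s02)) + norm.logpdf(y, theta, np.sqrt(sig2))

F = trapezoid(q * (log_joint - log_q), theta)                              # eq. (15.4)
KL = trapezoid(q * (log_q - norm.logpdf(theta, mp, np.sqrt(sp2))), theta)  # eq. (15.5)

print(f"ln p(y)         = {log_evidence:.6f}")
print(f"F(q) + KL       = {F + KL:.6f}")         # equals ln p(y), eq. (15.3)
print(f"F(q) <= ln p(y): {F <= log_evidence}")   # eq. (15.7)
```

Improving q here would shrink the KL term and push F(q) toward ln p(y), previewing the optimization logic developed next.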
This fact can be exploited in the numerical application of variational inference to probabilistic models: because the log model evidence is a fixed quantity which only depends on the choice of p(y, ϑ) and a specific data realization, manipulating the variational density q(ϑ) for a given data set in such a manner that the variational free energy increases has two consequences: first, the lower bound to the log model evidence becomes tighter, and the variational free energy a better approximation to the log model evidence. Second, because the left-hand side of eq. (15.3) remains constant, the KL divergence between the true posterior and its variational approximation decreases, which renders the variational density q(ϑ) an increasingly better approximation to the true posterior distribution p(ϑ|y). Because the variational free energy is a lower bound to the log model evidence, it is also referred to as the evidence lower bound (ELBO). The maximization of a variational free energy in terms of a variational density is a very general approach for posterior density and log model evidence approximation. Like the maximum likelihood approach, it serves as a guiding principle rather than a concrete numerical algorithm. In other words, algorithms that make use of the variational free energy log model evidence decomposition are jointly referred to as variational inference algorithms, but many variants exist. In the following two sections, we will discuss two specific variants and illustrate them with examples. The variants will be referred to as free-form mean-field variational inference and fixed-form mean-field variational inference. Here, the term mean-field refers to a factorization assumption with respect to the variational densities
over s sets of the unobserved random variables,
\[
q(\vartheta) = \prod_{i=1}^{s} q(\vartheta_i). \tag{15.8}
\]
Such a factorization allows the variational free energy to be optimized with respect to the variational densities q(ϑ_i) in a coordinate-wise fashion for i = 1, ..., s, a procedure sometimes referred to as coordinate ascent variational inference (CAVI). The free-form and fixed-form variants of mean-field variational inference then differ in their assumptions about the variational densities q(ϑ_i): the defining feature of the free-form mean-field variational inference approach is that the parametric form of the variational densities is not predetermined, but analytically evaluated based on a central result from variational calculus. As such, the free-form mean-field variational inference approach is useful to emphasize the roots of variational inference in variational calculus, but it is also analytically quite demanding. In a functional neuroimaging context, the free-form mean-field approach thus serves primarily didactic purposes. The fixed-form mean-field variational inference approach, on the other hand, is characterized by predetermined functional forms of the variational densities and is of high practical relevance in functional neuroimaging. In particular, a fixed-form mean-field variational inference approach that rests on Gaussian variational densities enjoys widespread popularity in functional neuroimaging (under the label variational Bayes Laplace algorithm) and theoretical neuroscience (under the label free energy principle). In contrast to the free-form mean-field approach, the fixed-form mean-field approach is less analytically demanding and replaces a variational optimization problem with a standard numerical optimization problem. This is achieved by analytically evaluating the variational free energy in terms of the parameters of the variational densities.
15.2 Free-form mean-field variational inference
Free-form mean-field variational inference rests on a factorization of the variational density over sets of unobserved random variables,
\[
q(\vartheta) = q(\vartheta_s)\, q(\vartheta_{\setminus s}), \tag{15.9}
\]
referred to as a mean-field approximation. In (15.9), ϑ_{\s} denotes all unobserved variables not in the sth group. For the factorization (15.9), the variational free energy becomes a function of two arguments, namely q(ϑ_s) and q(ϑ_{\s}). Due to the complexity of the integrals involved, a simultaneous analytical maximization of the variational free energy with respect to both its arguments is often difficult to achieve, and a coordinate-wise approach, i.e., maximization first with respect to q(ϑ_s) and second with respect to q(ϑ_{\s}), is preferred. Notably, the assumed factorization over sets of variables corresponds to the assumption that the respective variables form stochastically independent contributions to the multivariate posterior, which, depending on the true form of the generative model, may have weak or strong implications for the validity of the ensuing posterior inference.
The question is thus how to obtain the arguments q(ϑ_s) and q(ϑ_{\s}) that maximize the variational free energy. It turns out that this challenge corresponds to a well-known problem in statistical physics, which has long been solved in a general fashion using variational calculus (Hinton and Van Camp, 1993). In contrast to ordinary calculus, which deals with the optimization of functions with respect to real numbers, variational calculus deals with the optimization of functionals, i.e., functions of functions, with respect to their function arguments. Using variational calculus, it can be shown that the variational free energy is maximized with respect to the unobserved variable partition ϑ_s if q(ϑ_s) is set proportional (i.e., equal up to a scaling factor) to the exponential of the expected log joint probability of y and ϑ under the variational density over ϑ_{\s}. Formally, this can be written as
\[
q(\vartheta_s) \propto \exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right). \tag{15.10}
\]
The result stated in eq. (15.10) is fundamental. It represents the general free-form mean-field variational inference strategy to obtain variational densities over unobserved variables in light of data while maximizing the lower bound to the log model evidence. We thus refer to eq. (15.10) as the free-form variational inference theorem. In the following, we provide two proofs that (15.10) maximizes the variational free energy with respect to q(ϑ_s). The first proof is constructive in that it uses the constrained Gateaux derivative approach from variational calculus (Chapter 7) to generate the solution (15.10). The second proof avoids recourse to variational calculus techniques and uses a reformulation of the variational free energy in terms of a KL divergence involving the right-hand side of (15.10) (cf. Tzikas et al. (2008)).
Proof I of (15.10)
We first note that the aim of the free-form mean-field variational inference approach is to approximate the log marginal likelihood ln p(y) by iteratively maximizing its lower bound F(q(ϑ_s), q(ϑ_{\s})) with respect to the arguments q(ϑ_s) and q(ϑ_{\s}). For iterations i = 1, 2, ..., during the first maximization, which yields q^{(i+1)}(ϑ_s), the density q^{(i)}(ϑ_{\s}) is treated as a constant, while during the second maximization, which yields q^{(i+1)}(ϑ_{\s}), the density q^{(i+1)}(ϑ_s) is treated as a constant. However, as ϑ_s and ϑ_{\s} may be used interchangeably, we here concern ourselves only with the case of maximizing F(q(ϑ_s), q^{(i)}(ϑ_{\s})) with respect to q(ϑ_s). To obtain an expression for q^{(i+1)}(ϑ_s), we thus consider the variational free energy functional
\[
F\left(q(\vartheta_s), q^{(i)}(\vartheta_{\setminus s})\right) = \iint q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s}) \ln \frac{p(y, \vartheta)}{q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s})}\, d\vartheta_s\, d\vartheta_{\setminus s}, \quad \text{where } \int q(\vartheta_s)\, d\vartheta_s = 1. \tag{15.11}
\]
In this case, the extended Lagrange function (cf. Chapter 7) is given by
\[
\bar{F}\left(q(\vartheta_s), q^{(i)}(\vartheta_{\setminus s})\right) = \int q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s}) \ln \frac{p(y, \vartheta)}{q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s})}\, d\vartheta_{\setminus s} + \lambda q(\vartheta_s) - \lambda. \tag{15.12}
\]
Furthermore, the Gateaux derivative δF̄(q(ϑ_s), q^{(i)}(ϑ_{\s})) is given by the derivative of F̄ with respect to q(ϑ_s), because F̄ is not a function of q′(ϑ_s). One thus obtains
\[
\begin{aligned}
\delta \bar{F}\left(q(\vartheta_s), q^{(i)}(\vartheta_{\setminus s})\right)
&= \frac{\partial}{\partial q(\vartheta_s)} \left( \int q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s}) \ln \frac{p(y, \vartheta)}{q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s})}\, d\vartheta_{\setminus s} + \lambda q(\vartheta_s) - \lambda \right) \\
&= \frac{\partial}{\partial q(\vartheta_s)} \left( \int q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} - \int q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s}) \ln\left( q(\vartheta_s)\, q^{(i)}(\vartheta_{\setminus s}) \right) d\vartheta_{\setminus s} + \lambda q(\vartheta_s) - \lambda \right) \\
&= \frac{\partial}{\partial q(\vartheta_s)} \left( q(\vartheta_s) \int q^{(i)}(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} - q(\vartheta_s) \ln q(\vartheta_s) - q(\vartheta_s) \int q^{(i)}(\vartheta_{\setminus s}) \ln q^{(i)}(\vartheta_{\setminus s})\, d\vartheta_{\setminus s} + \lambda q(\vartheta_s) - \lambda \right) \\
&= \int q^{(i)}(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} - \ln q(\vartheta_s) - 1 - \int q^{(i)}(\vartheta_{\setminus s}) \ln q^{(i)}(\vartheta_{\setminus s})\, d\vartheta_{\setminus s} + \lambda \\
&= \int q^{(i)}(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} - \ln q(\vartheta_s) + c,
\end{aligned} \tag{15.13}
\]
where the third equality uses ∫ q^{(i)}(ϑ_{\s}) dϑ_{\s} = 1 and where
\[
c := \lambda - 1 - \int q^{(i)}(\vartheta_{\setminus s}) \ln q^{(i)}(\vartheta_{\setminus s})\, d\vartheta_{\setminus s}. \tag{15.14}
\]
Setting the Gateaux derivative to zero thus yields
\[
\ln q^{(i+1)}(\vartheta_s) = \int q^{(i)}(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} + c. \tag{15.15}
\]
Taking the exponential and subsuming the multiplicative constant under the proportionality sign then yields the free-form variational inference theorem for mean-field approximations,
\[
q^{(i+1)}(\vartheta_s) \propto \exp\left( \int q^{(i)}(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right). \tag{15.16}
\]
∎
Proof II of (15.10)
Consider the maximization of the variational free energy with respect to q(ϑ_s):
Figure 15.2. The log model evidence decomposition visualized in Figure 15.1 is exploited in numerical algorithms for free-form VB inference: based on a mean-field approximation q(ϑ) = q(ϑ_s)q(ϑ_{\s}), the variational free energy can be maximized in a coordinate-wise fashion. Maximizing the variational free energy in turn has two implications: it decreases the KL divergence between q(ϑ) and the true posterior p(ϑ|y) and renders the variational free energy a closer approximation to the log model evidence. This holds true because the log model evidence for a given observation ỹ is constant (represented by the constant length of the vertical bar) and the KL divergence is non-negative.
\[
\begin{aligned}
F(q(\vartheta_s)q(\vartheta_{\setminus s}))
&= \iint q(\vartheta_s)\, q(\vartheta_{\setminus s}) \ln \frac{p(y, \vartheta)}{q(\vartheta_s)\, q(\vartheta_{\setminus s})}\, d\vartheta_s\, d\vartheta_{\setminus s} \\
&= \iint q(\vartheta_s)\, q(\vartheta_{\setminus s}) \left( \ln p(y, \vartheta) - \ln q(\vartheta_s) - \ln q(\vartheta_{\setminus s}) \right) d\vartheta_s\, d\vartheta_{\setminus s} \\
&= \iint q(\vartheta_s)\, q(\vartheta_{\setminus s}) \left( \ln p(y, \vartheta) - \ln q(\vartheta_s) \right) d\vartheta_{\setminus s}\, d\vartheta_s - \int q(\vartheta_{\setminus s}) \ln q(\vartheta_{\setminus s}) \left( \int q(\vartheta_s)\, d\vartheta_s \right) d\vartheta_{\setminus s} \\
&= \int q(\vartheta_s) \left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right) d\vartheta_s - \int q(\vartheta_s) \ln q(\vartheta_s)\, d\vartheta_s - c \\
&= \int q(\vartheta_s) \ln \exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right) d\vartheta_s - \int q(\vartheta_s) \ln q(\vartheta_s)\, d\vartheta_s - c \\
&= - \int q(\vartheta_s) \ln \frac{q(\vartheta_s)}{\exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right)}\, d\vartheta_s - c \\
&= - \mathrm{KL}\left( q(\vartheta_s) \,\Big\|\, \exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right) \right) - c,
\end{aligned} \tag{15.17}
\]
where the constant c := ∫ q(ϑ_{\s}) ln q(ϑ_{\s}) dϑ_{\s} does not depend on q(ϑ_s), and where the fourth equality uses ∫ q(ϑ_s) dϑ_s = ∫ q(ϑ_{\s}) dϑ_{\s} = 1.
Maximizing the negative KL divergence term, and thus the variational free energy, is achieved by setting
\[
q(\vartheta_s) \propto \exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right), \tag{15.18}
\]
i.e., by equating q(ϑ_s) with the normalized exponentiated expected log joint density. ∎

Based on the free-form variational inference theorem, algorithmic implementations of variational inference can use an iterative coordinate-wise variational free energy ascent. For iterations i = 0, 1, 2, ..., this strategy proceeds as follows. The ascent starts by initializing q^{(0)}(ϑ_s) and q^{(0)}(ϑ_{\s}), commonly by equating them to the prior distributions over ϑ_s and ϑ_{\s}, respectively. Based on (15.10), it then continues by maximizing the variational free energy F(q^{(i)}(ϑ_s), q^{(i)}(ϑ_{\s})), first with respect to the density q^{(i)}(ϑ_s), given q^{(i)}(ϑ_{\s}), yielding the updated density q^{(i+1)}(ϑ_s). Then, by exchanging the labelling of ϑ_s and ϑ_{\s} in eq. (15.9), the ascent continues by maximizing the variational free energy with respect to the density q^{(i)}(ϑ_{\s}), given q^{(i+1)}(ϑ_s), yielding q^{(i+1)}(ϑ_{\s}). This procedure is then iterated until convergence. Commonly, the initialization step sets the variational density q^{(0)}(ϑ) to the prior distribution p(ϑ). This defines the starting point of the iterative procedure as representative of the knowledge about the
unknown variables before observed data is taken into account. Further, this choice often enables the use of the well-known benefits of parameterized conjugate priors in the context of variational inference. The initialization of the variational density in terms of the prior distribution and the subsequent optimization of the variational densities should, however, not be confused with an empirical Bayesian approach, in which priors themselves are learned from the data: on each iteration of the variational inference algorithm sketched above, the variational density corresponds to the approximate posterior distribution, not an updated prior distribution. An empirical Bayesian extension of the variational inference algorithm, on the other hand, would correspond to a variation of the prior distribution (specifying the variational inference algorithm's starting conditions) after convergence, with the aim of increasing the log model evidence per se. The variational inference algorithm as described here merely increases the lower bound to the fixed log model evidence, which is determined by the choice of the prior p(ϑ) and likelihood p(y|ϑ), i.e., the generative model p(y, ϑ). To summarize, a general iterative algorithm for free-form mean-field variational inference is outlined below. This iterative scheme shares some similarities with expectation-maximization algorithms for models comprising unobserved variables (Dempster et al., 1977; Wu, 1983). In fact, variational inference can be viewed as a generalization of expectation-maximization algorithms for maximum likelihood estimation to Bayesian inference. For the general linear model, this line of thought is further investigated in Chapter 22.
An iterative free-form mean-field variational inference algorithm

Initialization
0. Initialize q^{(0)}(ϑ_s) and q^{(0)}(ϑ_{\s}) appropriately, e.g., by setting q^{(0)}(ϑ_s) ∝ p(ϑ_s) and q^{(0)}(ϑ_{\s}) ∝ p(ϑ_{\s}).

Until convergence
1. Set q^{(i+1)}(ϑ_s) proportional to exp( ∫ q^{(i)}(ϑ_{\s}) ln p(y, ϑ) dϑ_{\s} ).
2. Set q^{(i+1)}(ϑ_{\s}) proportional to exp( ∫ q^{(i+1)}(ϑ_s) ln p(y, ϑ) dϑ_s ).
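In code, the scheme above is a generic coordinate ascent whose model-specific content is confined to the two update steps. The following is a minimal organizational sketch (the function and variable names are our own illustrative choices, not a fixed API); concrete update functions for a Gaussian model are derived in the next subsection:

```python
import numpy as np

def free_form_vb(update_s, update_not_s, q_s, q_not_s,
                 free_energy, max_iter=64, tol=1e-10):
    """Generic coordinate-ascent skeleton for free-form mean-field VB.

    update_s and update_not_s implement steps 1 and 2 of the algorithm
    box, i.e., the model-specific instances of eq. (15.10); q_s and
    q_not_s hold the parameters of the current variational densities;
    free_energy evaluates F for convergence monitoring.
    """
    F_old, F_new = -np.inf, -np.inf
    for i in range(max_iter):
        q_s = update_s(q_not_s)        # step 1: q^(i+1)(theta_s)
        q_not_s = update_not_s(q_s)    # step 2: q^(i+1)(theta_\s)
        F_new = free_energy(q_s, q_not_s)
        if F_new - F_old < tol:        # free energy ascent has leveled off
            break
        F_old = F_new
    return q_s, q_not_s, F_new
```

Because each step can only increase the free energy, the sequence of free energy values is non-decreasing, which makes it a natural convergence criterion.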
Free-form mean-field variational inference for a Gaussian model
Probabilistic model
To demonstrate free-form mean-field variational inference, we consider the estimation of the expectation and precision parameter of a univariate Gaussian distribution based on n independently and identically distributed data realizations y_i, i = 1, ..., n (Penny and Roberts, 2000; Bishop, 2006; Chappell et al., 2008; Murphy, 2012). To this end, we assume that the y_i, i = 1, ..., n are generated by a univariate Gaussian distribution with true, but unknown, expectation parameter µ ∈ ℝ and precision parameter λ > 0. We denote the concatenation of the data realizations by y := (y_1, ..., y_n)^T ∈ ℝ^n. To recapitulate, the aim of variational inference is, based on appropriately chosen prior densities, first, to obtain posterior densities that quantify the remaining uncertainty over the true, but unknown, unobservable variables given the observable variables, and second, to obtain an approximation to the log model evidence, i.e., the log probability of the data given the probabilistic model. In the current example, the probabilistic model takes the form of the joint probability density function
\[
p(y, \mu, \lambda) = p(\mu, \lambda)\, p(y \mid \mu, \lambda) = p(\mu, \lambda) \prod_{i=1}^{n} p(y_i \mid \mu, \lambda). \tag{15.19}
\]
A possible choice for a joint prior density of the unobservable variables is given by the product of a univariate Gaussian density for µ and a Gamma density for λ, i.e.,
\[
p(y, \mu, \lambda) := p(\mu)\, p(\lambda)\, p(y \mid \mu, \lambda) := N\left(\mu; m_\mu, s_\mu^2\right) G\left(\lambda; a_\lambda, b_\lambda\right) \prod_{i=1}^{n} N\left(y_i; \mu, \lambda^{-1}\right). \tag{15.20}
\]
Note that many other prior densities are conceivable. In fact, a more commonly discussed scenario is the case of a non-independent Gaussian-Gamma prior density (Bishop, 2006; Murphy, 2012). With respect to the factorized prior density considered here, the non-independent Gaussian-Gamma prior density has the advantage that it belongs to the conjugate-exponential class and allows for the derivation of an exact
analytical solution for the form of the posterior distribution. On the other hand, it is not clear in which applied scenarios a dependency of the prior density over the expectation parameter µ on the prior density over λ is in fact a reasonable assumption. We thus focus here on the factorized prior density, as it corresponds to a more parsimonious choice than its non-factorized counterpart. Furthermore, it demonstrates how variational inference can be used to derive posterior density approximations in model scenarios where no analytical treatment is possible.
Variational inference
For the posterior density, we consider the mean-field approximation
p(µ, λ|y) ≈ q(µ)q(λ). (15.21)
Recall that the free-form mean-field variational inference theorem states that the variational density over the unobservable variable partition ϑ_s is given by
\[
q(\vartheta_s) \propto \exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right). \tag{15.22}
\]
For the current example,
\[
q(\mu, \lambda) := q(\mu)\, q(\lambda), \tag{15.23}
\]
and thus
\[
q(\mu) = c_\mu \exp\left( \int q(\lambda) \ln p(y, \mu, \lambda)\, d\lambda \right) \tag{15.24}
\]
and
\[
q(\lambda) = c_\lambda \exp\left( \int q(\mu) \ln p(y, \mu, \lambda)\, d\mu \right), \tag{15.25}
\]
where c_µ and c_λ denote proportionality constants that render the proportionality statement in (15.22) an equality in (15.24) and (15.25), respectively. In the following, we derive an iterative scheme based on these equations. For this purpose, it is first helpful to make the iterative nature of the approach explicit by denoting the variational densities q(µ) and q(λ) as q^{(i)}(µ) and q^{(i)}(λ). This also stresses the fact that in eqs. (15.24) and (15.25), the left-hand variational densities refer to their state at the (i+1)th algorithm iteration, while the right-hand variational densities refer to their state at the ith algorithm iteration. Second, as we are dealing with densities from the exponential family, it is helpful to log-transform both eqs. (15.24) and (15.25). For i = 0, 1, 2, ..., eqs. (15.24) and (15.25) may thus be rewritten as
\[
\ln q^{(i+1)}(\mu) := \int q^{(i)}(\lambda) \ln p(y, \mu, \lambda)\, d\lambda + \ln c_\mu \tag{15.26}
\]
and
\[
\ln q^{(i+1)}(\lambda) := \int q^{(i+1)}(\mu) \ln p(y, \mu, \lambda)\, d\mu + \ln c_\lambda. \tag{15.27}
\]
To obtain an expression for q(i+1)(µ), we first note that we can express eq. (15.26) as
\[
\ln q^{(i+1)}(\mu) = -\frac{1}{2} \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} (y_i - \mu)^2 - \frac{1}{2 s_\mu^2} (\mu - m_\mu)^2 + \tilde{c}_\mu, \tag{15.28}
\]
where c̃_µ denotes a constant comprising additive terms devoid of µ. Based on (15.28) and using the completing-the-square theorem for Gaussian distributions (cf. Chapter 10), we can then infer that q^{(i+1)}(µ) is given by a Gaussian density
\[
q^{(i+1)}(\mu) = N\left(\mu; m_\mu^{(i+1)}, s_\mu^{2(i+1)}\right), \tag{15.29}
\]
with parameters
\[
m_\mu^{(i+1)} = \frac{m_\mu + s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i}{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}} \quad \text{and} \quad s_\mu^{2(i+1)} = \frac{s_\mu^2}{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}}. \tag{15.30}
\]
Next, to obtain an expression for q(i+1)(λ), we first note that we can express eq. (15.27) as
\[
\ln q^{(i+1)}(\lambda) = \frac{n}{2} \ln \lambda - \frac{\lambda}{2} \left\langle \sum_{i=1}^{n} (y_i - \mu)^2 \right\rangle_{q^{(i+1)}(\mu)} + (a_\lambda - 1) \ln \lambda - \frac{\lambda}{b_\lambda} + \tilde{c}_\lambda, \tag{15.31}
\]
Figure 15.3. Free-form variational inference for the Gaussian. (A) The panels depict the true underlying data model p(y|µ, λ), for µ = 1 and λ = 5, as a solid line and n = 10 samples y_i from this model as red dots on the abscissa. Based on these samples, a variational approximation q(µ)q(λ) is updated on each iteration of the VB algorithm. The first panel of (A) shows the univariate Gaussian model as approximated by the expectations over q(µ) and q(λ) as a dashed line. The second panel of (A) shows the effect of the update of the density q(µ) on the first iteration of the algorithm. As q(µ) governs the mean of the univariate Gaussian, the dashed Gaussian is now centered on the mean of the data points. The third panel of (A) shows the effect of the update of the density q(λ) on the first iteration of the algorithm. As q(λ) governs the precision of the univariate Gaussian model, the dashed Gaussian updates its variance based on the data variability. The fourth and fifth panels of (A) show the corresponding two steps on the 8th iteration. (B) The panels of (B) show the factorized variational density q(µ)q(λ) over VB algorithm iterations. The white dot in each panel indicates the true underlying parameters that gave rise to the observed data. Note that these parameters were not sampled from the prior density, but that the prior density embeds the initial uncertainty about this true, but unknown, parameter value before the observation of any data. The ordering of the panels is as in (A). (C) The panel shows the evolution of the variational free energy over iterations of the VB algorithm. For the current model and data set, the variational free energy levels off from approximately 4 iterations onwards. In the variational inference framework, the final value of the variational free energy after convergence of the algorithm corresponds to the approximation to the log model evidence ln p(y).
where c̃_λ denotes a constant comprising additive terms devoid of λ. Expressing the right-hand side of (15.31) in multiplicative terms involving λ and taking exponentials, it then follows that q^{(i+1)}(λ) is proportional to a Gamma density
\[
G\left(\lambda; a_\lambda^{(i+1)}, b_\lambda^{(i+1)}\right) \tag{15.32}
\]
with parameters
\[
a_\lambda^{(i+1)} = \frac{n}{2} + a_\lambda \quad \text{and} \quad b_\lambda^{(i+1)} = \left( \frac{1}{b_\lambda} + \frac{1}{2} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i+1)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i+1)}\right)^2 + s_\mu^{2(i+1)} \right) \right) \right)^{-1}. \tag{15.33}
\]
A number of things are noteworthy. First, the Gaussian and Gamma density forms of the variational densities q^{(i+1)}(µ) and q^{(i+1)}(λ) follow directly from the form of the probabilistic model eq. (15.20) and the free-form mean-field theorem for variational inference. In other words, the functional forms of the densities q^{(i+1)}(µ) and q^{(i+1)}(λ) are not predetermined, but fall out of the variational inference approach automatically, hence the label free-form variational inference. Second, if the variational density q^{(0)}(λ) is initialized using the prior density p(λ), the expected value ⟨λ⟩_{q^{(i)}(λ)} is determined by the prior parameters a_λ and b_λ for i = 0, and by the variational density parameters a_λ^{(i)} and b_λ^{(i)} for i = 1, 2, .... In other words, the parameter update equations (15.30) and (15.33) are fully determined in terms of the prior density parameters a_λ, b_λ, m_µ, s_µ², the data realizations y_1, y_2, ..., y_n, and the variational density parameters m_µ^{(i)}, s_µ^{2(i)}, a_λ^{(i)} and b_λ^{(i)}. Third, an explicit form of the variational free energy is not required for its maximization by means of the variational densities q^{(i)}(µ) and q^{(i)}(λ). It is nevertheless useful to evaluate it in order to monitor the progression of the iterative algorithm. For the current example, it takes the form
\[
F : \mathbb{R}^n \times \mathbb{R} \times \mathbb{R}_{>0} \times \mathbb{R}_{>0} \times \mathbb{R}_{>0} \to \mathbb{R}, \quad \left(y, m_\mu^{(i)}, s_\mu^{2(i)}, a_\lambda^{(i)}, b_\lambda^{(i)}\right) \mapsto F\left(y, m_\mu^{(i)}, s_\mu^{2(i)}, a_\lambda^{(i)}, b_\lambda^{(i)}\right),
\]
with, for fixed prior parameters m_µ, s_µ², a_λ, b_λ,
\[
\begin{aligned}
F\left(y, m_\mu^{(i)}, s_\mu^{2(i)}, a_\lambda^{(i)}, b_\lambda^{(i)}\right)
:= &\; \frac{n}{2} \left( \psi\left(a_\lambda^{(i)}\right) + \ln b_\lambda^{(i)} \right) - \frac{n}{2} \ln 2\pi - \frac{1}{2} a_\lambda^{(i)} b_\lambda^{(i)} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i)}\right)^2 + s_\mu^{2(i)} \right) \right) \\
&- \frac{1}{2} \ln \frac{s_\mu^2}{s_\mu^{2(i)}} - \frac{s_\mu^{2(i)} + \left(m_\mu^{(i)}\right)^2 - 2 m_\mu^{(i)} m_\mu + m_\mu^2}{2 s_\mu^2} + \frac{1}{2} \\
&- \left( a_\lambda^{(i)} - a_\lambda \right) \psi\left( a_\lambda^{(i)} \right) + \ln \Gamma\left( a_\lambda^{(i)} \right) - \ln \Gamma\left( a_\lambda \right) - a_\lambda \ln \frac{b_\lambda}{b_\lambda^{(i)}} - a_\lambda^{(i)}\, \frac{b_\lambda^{(i)} - b_\lambda}{b_\lambda},
\end{aligned} \tag{15.34}
\]
where Γ and ψ denote the Gamma and digamma functions, respectively, and where the three groups of terms correspond to the average likelihood and the negative KL divergences derived in the proof of (15.34) below. We visualize the free-form mean-field variational inference approach for the expectation and precision parameter of a univariate Gaussian in Figure 15.3.
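The updates (15.30) and (15.33) and the free energy (15.34) translate directly into a few lines of code. The following sketch uses simulated data and arbitrary prior parameter values (all numerical choices are illustrative, not prescribed by the text); the printed free energy sequence is non-decreasing and levels off after a few iterations, mirroring panel (C) of Figure 15.3:

```python
import numpy as np
from scipy.special import digamma, gammaln

# Free-form mean-field VB for the univariate Gaussian example.
# Data, prior parameters, and iteration count are illustrative choices.
rng = np.random.default_rng(0)
n, mu_true, lam_true = 10, 1.0, 5.0
y = rng.normal(mu_true, 1.0 / np.sqrt(lam_true), size=n)
Sy, Syy = y.sum(), (y ** 2).sum()

# Prior parameters: mu ~ N(m, s2), lambda ~ G(a, b) (shape-scale)
m, s2, a, b = 0.0, 1.0, 2.0, 1.0

def free_energy(mi, s2i, ai, bi):
    """Variational free energy of eq. (15.34)."""
    quad = Syy - 2.0 * mi * Sy + n * (mi ** 2 + s2i)
    avg_ll = (n / 2) * (digamma(ai) + np.log(bi)) \
             - (n / 2) * np.log(2 * np.pi) - 0.5 * ai * bi * quad
    kl_mu = 0.5 * np.log(s2 / s2i) + (s2i + (mi - m) ** 2) / (2 * s2) - 0.5
    kl_lam = ((ai - a) * digamma(ai) - gammaln(ai) + gammaln(a)
              + a * np.log(b / bi) + ai * (bi - b) / b)
    return avg_ll - kl_mu - kl_lam

# Initialize the variational densities with the prior (iteration i = 0)
mi, s2i, ai, bi = m, s2, a, b
for i in range(10):
    E_lam = ai * bi                          # <lambda> under q^(i)(lambda)
    s2i = s2 / (1.0 + n * s2 * E_lam)        # eq. (15.30)
    mi = (m + s2 * E_lam * Sy) / (1.0 + n * s2 * E_lam)
    ai = n / 2 + a                           # eq. (15.33)
    bi = 1.0 / (1.0 / b + 0.5 * (Syy - 2.0 * mi * Sy + n * (mi ** 2 + s2i)))
    print(f"iteration {i + 1:2d}: F = {free_energy(mi, s2i, ai, bi):.4f}")
```

Note that ⟨λ⟩ under the Gamma density in the shape-scale parameterization used throughout this example is a^{(i)}_λ b^{(i)}_λ, which is all the µ-update needs from the λ-update, and vice versa via m^{(i+1)}_µ and s^{2(i+1)}_µ.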
Proof of eqs. (15.28), (15.29) and (15.30)
We first note that with the probabilistic model (15.20), eq. (15.26) can be rewritten as
\[
\begin{aligned}
\ln q^{(i+1)}(\mu) &= \left\langle \ln p(y, \mu, \lambda) \right\rangle_{q^{(i)}(\lambda)} + \ln c_\mu \\
&= \left\langle \ln \left( \prod_{i=1}^{n} p(y_i \mid \mu, \lambda)\, p(\mu)\, p(\lambda) \right) \right\rangle_{q^{(i)}(\lambda)} + \ln c_\mu \\
&= \sum_{i=1}^{n} \left\langle \ln p(y_i \mid \mu, \lambda) \right\rangle_{q^{(i)}(\lambda)} + \left\langle \ln p(\mu) \right\rangle_{q^{(i)}(\lambda)} + \left\langle \ln p(\lambda) \right\rangle_{q^{(i)}(\lambda)} + \ln c_\mu.
\end{aligned} \tag{15.35}
\]
Substitution of the example-specific densities (cf. eq. (15.20))
\[
p(\mu) = N\left(\mu; m_\mu, s_\mu^2\right), \quad p(\lambda) = G\left(\lambda; a_\lambda, b_\lambda\right), \quad \text{and} \quad p(y_i \mid \mu, \lambda) = N\left(y_i; \mu, \lambda^{-1}\right) \tag{15.36}
\]
then yields
\[
\begin{aligned}
\ln q^{(i+1)}(\mu) &= \sum_{i=1}^{n} \left\langle \ln \left( \lambda^{\frac{1}{2}} (2\pi)^{-\frac{1}{2}} \exp\left( -\frac{\lambda}{2} (y_i - \mu)^2 \right) \right) \right\rangle_{q^{(i)}(\lambda)} \\
&\quad + \left\langle \ln \left( \left(2\pi s_\mu^2\right)^{-\frac{1}{2}} \exp\left( -\frac{1}{2 s_\mu^2} (\mu - m_\mu)^2 \right) \right) \right\rangle_{q^{(i)}(\lambda)} \\
&\quad + \left\langle \ln \left( \frac{1}{\Gamma(a_\lambda)\, b_\lambda^{a_\lambda}}\, \lambda^{a_\lambda - 1} \exp\left( -\frac{\lambda}{b_\lambda} \right) \right) \right\rangle_{q^{(i)}(\lambda)} + \ln c_\mu \\
&= \left\langle \frac{n}{2} \ln \lambda - \frac{n}{2} \ln 2\pi - \frac{\lambda}{2} \sum_{i=1}^{n} (y_i - \mu)^2 \right\rangle_{q^{(i)}(\lambda)} \\
&\quad + \left\langle -\frac{1}{2} \ln s_\mu^2 - \frac{1}{2} \ln 2\pi - \frac{1}{2 s_\mu^2} (\mu - m_\mu)^2 \right\rangle_{q^{(i)}(\lambda)} \\
&\quad + \left\langle -\ln \Gamma(a_\lambda) - a_\lambda \ln b_\lambda + (a_\lambda - 1) \ln \lambda - \frac{\lambda}{b_\lambda} \right\rangle_{q^{(i)}(\lambda)} + \ln c_\mu.
\end{aligned} \tag{15.37}
\]
Grouping all terms devoid of µ in a constant c̃_µ and accounting for the linearity of expectations then results in
\[
\ln q^{(i+1)}(\mu) = -\frac{1}{2} \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} (y_i - \mu)^2 - \frac{1}{2 s_\mu^2} (\mu - m_\mu)^2 + \tilde{c}_\mu. \tag{15.38}
\]
Next, to use the completing-the-square theorem for inferring that q^{(i+1)}(µ) conforms to a Gaussian density with the parameters of eq. (15.30), we first rewrite the right-hand side of eq. (15.28) as a quadratic expression in µ. We have
\[
\begin{aligned}
\ln q^{(i+1)}(\mu) &= -\frac{1}{2} \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} (y_i - \mu)^2 - \frac{(\mu - m_\mu)^2}{2 s_\mu^2} + \tilde{c}_\mu \\
&= -\frac{1}{2} \langle \lambda \rangle_{q^{(i)}(\lambda)} \left( \sum_{i=1}^{n} y_i^2 - 2\mu \sum_{i=1}^{n} y_i + n \mu^2 \right) - \frac{1}{2} \cdot \frac{\mu^2 - 2\mu m_\mu + m_\mu^2}{s_\mu^2} + \tilde{c}_\mu \\
&= -\frac{1}{2} \left( \left( n \langle \lambda \rangle_{q^{(i)}(\lambda)} + \frac{1}{s_\mu^2} \right) \mu^2 - 2 \left( \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i + \frac{m_\mu}{s_\mu^2} \right) \mu + \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i^2 + \frac{m_\mu^2}{s_\mu^2} \right) + \tilde{c}_\mu.
\end{aligned} \tag{15.39}
\]
Resolving brackets, grouping the terms devoid of µ with the constant c̃_µ (yielding an updated constant, again denoted c̃_µ), and re-expressing the coefficient of µ² then results in
\[
\begin{aligned}
\ln q^{(i+1)}(\mu) &= -\frac{1}{2} \left( n \langle \lambda \rangle_{q^{(i)}(\lambda)} + \frac{1}{s_\mu^2} \right) \mu^2 + \left( \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i + \frac{m_\mu}{s_\mu^2} \right) \mu + \tilde{c}_\mu \\
&= -\frac{1}{2} \left( \frac{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}}{s_\mu^2} \right) \mu^2 + \left( \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i + \frac{m_\mu}{s_\mu^2} \right) \mu + \tilde{c}_\mu.
\end{aligned} \tag{15.40}
\]
Using the completing-the-square theorem (cf. Chapter 10) in the form
\[
\exp\left( -\frac{1}{2} a x^2 + b x \right) \propto N\left(x; a^{-1} b, a^{-1}\right) \tag{15.41}
\]
then yields
\[
q^{(i+1)}(\mu) \propto N\left(\mu; m_\mu^{(i+1)}, s_\mu^{2(i+1)}\right), \tag{15.42}
\]
where the variational parameters m_µ^{(i+1)}, s_µ^{2(i+1)} may be expressed in terms of the expectation of λ under the ith variational density q^{(i)}(λ), the prior parameters m_µ and s_µ², and the data y_i as
\[
s_\mu^{2(i+1)} = \left( \frac{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}}{s_\mu^2} \right)^{-1} = \frac{s_\mu^2}{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}} \tag{15.43}
\]
and
\[
\begin{aligned}
m_\mu^{(i+1)} &= s_\mu^{2(i+1)} \left( \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i + \frac{m_\mu}{s_\mu^2} \right) \\
&= \frac{s_\mu^2}{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}} \cdot \frac{m_\mu + s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i}{s_\mu^2} \\
&= \frac{m_\mu + s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)} \sum_{i=1}^{n} y_i}{1 + n s_\mu^2 \langle \lambda \rangle_{q^{(i)}(\lambda)}}.
\end{aligned} \tag{15.44}
\]
∎
Proof of eq. (15.31) and eq. (15.33)
In analogy to the derivation of eq. (15.28), we have
\[
\begin{aligned}
\ln q^{(i+1)}(\lambda) &= \sum_{i=1}^{n} \left\langle \ln \left( \lambda^{\frac{1}{2}} (2\pi)^{-\frac{1}{2}} \exp\left( -\frac{\lambda}{2} (y_i - \mu)^2 \right) \right) \right\rangle_{q^{(i+1)}(\mu)} \\
&\quad + \left\langle \ln \left( \left(2\pi s_\mu^2\right)^{-\frac{1}{2}} \exp\left( -\frac{1}{2 s_\mu^2} (\mu - m_\mu)^2 \right) \right) \right\rangle_{q^{(i+1)}(\mu)} \\
&\quad + \left\langle \ln \left( \frac{1}{\Gamma(a_\lambda)\, b_\lambda^{a_\lambda}}\, \lambda^{a_\lambda - 1} \exp\left( -\frac{\lambda}{b_\lambda} \right) \right) \right\rangle_{q^{(i+1)}(\mu)} + \ln c_\lambda \\
&= \left\langle \frac{n}{2} \ln \lambda - \frac{n}{2} \ln 2\pi - \frac{\lambda}{2} \sum_{i=1}^{n} (y_i - \mu)^2 \right\rangle_{q^{(i+1)}(\mu)} \\
&\quad + \left\langle -\frac{1}{2} \ln s_\mu^2 - \frac{1}{2} \ln 2\pi - \frac{1}{2 s_\mu^2} (\mu - m_\mu)^2 \right\rangle_{q^{(i+1)}(\mu)} \\
&\quad + \left\langle -\ln \Gamma(a_\lambda) - a_\lambda \ln b_\lambda + (a_\lambda - 1) \ln \lambda - \frac{\lambda}{b_\lambda} \right\rangle_{q^{(i+1)}(\mu)} + \ln c_\lambda.
\end{aligned} \tag{15.45}
\]
Grouping all terms devoid of λ in a constant c̃_λ and using the linearity of expectations then simplifies the above to
\[
\begin{aligned}
\ln q^{(i+1)}(\lambda) &= \left\langle \frac{n}{2} \ln \lambda - \frac{\lambda}{2} \sum_{i=1}^{n} (y_i - \mu)^2 \right\rangle_{q^{(i+1)}(\mu)} + \left\langle (a_\lambda - 1) \ln \lambda - \frac{\lambda}{b_\lambda} \right\rangle_{q^{(i+1)}(\mu)} + \tilde{c}_\lambda \\
&= \frac{n}{2} \ln \lambda - \frac{\lambda}{2} \left\langle \sum_{i=1}^{n} (y_i - \mu)^2 \right\rangle_{q^{(i+1)}(\mu)} + (a_\lambda - 1) \ln \lambda - \frac{\lambda}{b_\lambda} + \tilde{c}_\lambda.
\end{aligned} \tag{15.46}
\]
Reorganizing the right-hand side of equation (15.31) in multiplicative terms involving λ and evaluating the expectations of µ under the variational density q^{(i+1)}(µ) yields
\[
\begin{aligned}
\ln q^{(i+1)}(\lambda) &= \left( \frac{n}{2} + a_\lambda - 1 \right) \ln \lambda - \left( \frac{1}{b_\lambda} + \frac{1}{2} \left\langle \sum_{i=1}^{n} (y_i - \mu)^2 \right\rangle_{q^{(i+1)}(\mu)} \right) \lambda + \tilde{c}_\lambda \\
&= \left( \frac{n}{2} + a_\lambda - 1 \right) \ln \lambda - \left( \frac{1}{b_\lambda} + \frac{1}{2} \left\langle \sum_{i=1}^{n} y_i^2 - 2\mu \sum_{i=1}^{n} y_i + n \mu^2 \right\rangle_{q^{(i+1)}(\mu)} \right) \lambda + \tilde{c}_\lambda \\
&= \left( \frac{n}{2} + a_\lambda - 1 \right) \ln \lambda - \left( \frac{1}{b_\lambda} + \frac{1}{2} \left( \sum_{i=1}^{n} y_i^2 - 2 \sum_{i=1}^{n} y_i \langle \mu \rangle_{q^{(i+1)}(\mu)} + n \langle \mu^2 \rangle_{q^{(i+1)}(\mu)} \right) \right) \lambda + \tilde{c}_\lambda \\
&= \left( \frac{n}{2} + a_\lambda - 1 \right) \ln \lambda - \left( \frac{1}{b_\lambda} + \frac{1}{2} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i+1)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i+1)}\right)^2 + s_\mu^{2(i+1)} \right) \right) \right) \lambda + \tilde{c}_\lambda.
\end{aligned} \tag{15.47}
\]
Taking the exponential on both sides then yields
\[
q^{(i+1)}(\lambda) \propto \lambda^{\left( \frac{n}{2} + a_\lambda - 1 \right)} \exp\left( -\left( \frac{1}{b_\lambda} + \frac{1}{2} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i+1)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i+1)}\right)^2 + s_\mu^{2(i+1)} \right) \right) \right) \lambda \right). \tag{15.48}
\]
Up to a normalization constant, q(i+1)(λ) is thus given by a Gamma density function
\[
q^{(i+1)}(\lambda) \propto G\left(\lambda; a_\lambda^{(i+1)}, b_\lambda^{(i+1)}\right) \tag{15.49}
\]
with parameters
\[
a_\lambda^{(i+1)} := \frac{n}{2} + a_\lambda \quad \text{and} \quad b_\lambda^{(i+1)} := \left( \frac{1}{b_\lambda} + \frac{1}{2} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i+1)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i+1)}\right)^2 + s_\mu^{2(i+1)} \right) \right) \right)^{-1}. \tag{15.50}
\]
∎
Proof of eq. (15.34)
We first reformulate the variational free energy functional as
\[
\begin{aligned}
F(q(\vartheta)) &= \int q(\vartheta) \ln \frac{p(y, \vartheta)}{q(\vartheta)}\, d\vartheta \\
&= \int q(\vartheta) \ln \frac{p(y \mid \vartheta)\, p(\vartheta)}{q(\vartheta)}\, d\vartheta \\
&= \int q(\vartheta) \left( \ln p(y \mid \vartheta) - \ln \frac{q(\vartheta)}{p(\vartheta)} \right) d\vartheta \\
&= \int q(\vartheta) \ln p(y \mid \vartheta)\, d\vartheta - \int q(\vartheta) \ln \frac{q(\vartheta)}{p(\vartheta)}\, d\vartheta \\
&= \int q(\vartheta) \ln p(y \mid \vartheta)\, d\vartheta - \mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta)),
\end{aligned} \tag{15.51}
\]
where the first term on the right-hand side is sometimes referred to as the average likelihood and the second term is the KL divergence between the variational and prior densities. We next evaluate the average likelihood term. To this end, substitution of the relevant probability densities yields
\[
\begin{aligned}
\int q(\vartheta) \ln p(y \mid \vartheta)\, d\vartheta
&= \iint q^{(i)}(\mu)\, q^{(i)}(\lambda) \ln \left( \prod_{i=1}^{n} N\left(y_i; \mu, \lambda^{-1}\right) \right) d\mu\, d\lambda \\
&= \iint q^{(i)}(\mu)\, q^{(i)}(\lambda) \ln \left( \prod_{i=1}^{n} \left( \frac{\lambda}{2\pi} \right)^{\frac{1}{2}} \exp\left( -\frac{\lambda}{2} (y_i - \mu)^2 \right) \right) d\mu\, d\lambda \\
&= \iint q^{(i)}(\mu)\, q^{(i)}(\lambda) \left( \frac{n}{2} \ln \lambda - \frac{n}{2} \ln 2\pi - \frac{\lambda}{2} \sum_{i=1}^{n} (y_i - \mu)^2 \right) d\mu\, d\lambda \\
&= \frac{n}{2} \int q^{(i)}(\lambda) \ln \lambda\, d\lambda - \int q^{(i)}(\lambda)\, \frac{\lambda}{2} \left( \int q^{(i)}(\mu) \sum_{i=1}^{n} (y_i - \mu)^2\, d\mu \right) d\lambda - \frac{n}{2} \ln 2\pi.
\end{aligned} \tag{15.52}
\]
The first integral term on the right-hand side of eq. (15.52) is the expectation of the logarithm of λ under the variational density q^{(i)}(λ) = G(λ; a_λ^{(i)}, b_λ^{(i)}) and evaluates to (cf. Johnson et al., 1994)
\[
\int q^{(i)}(\lambda) \ln \lambda\, d\lambda = \psi\left(a_\lambda^{(i)}\right) + \ln b_\lambda^{(i)}, \tag{15.53}
\]
where ψ denotes the digamma function.
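The identity (15.53) is easy to check empirically. A quick Monte Carlo sketch (shape and scale values chosen arbitrarily; note that numpy's Gamma sampler uses the same shape-scale parameterization as the density G(λ; a, b) above):

```python
import numpy as np
from scipy.special import digamma

# Monte Carlo check of eq. (15.53): E[ln lambda] = psi(a) + ln(b)
# for lambda ~ G(a, b) in the shape-scale parameterization.
a, b = 2.5, 0.7
rng = np.random.default_rng(1)
lam = rng.gamma(shape=a, scale=b, size=1_000_000)
print(np.mean(np.log(lam)))      # Monte Carlo estimate
print(digamma(a) + np.log(b))    # closed form, eq. (15.53)
```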
The second integral term on the right-hand side of eq. (15.52) evaluates to
\[
\begin{aligned}
\int q^{(i)}(\lambda)\, \frac{\lambda}{2} \left( \int q^{(i)}(\mu) \sum_{i=1}^{n} (y_i - \mu)^2\, d\mu \right) d\lambda
&= \int q^{(i)}(\lambda)\, \frac{\lambda}{2} \left( \int q^{(i)}(\mu) \left( \sum_{i=1}^{n} y_i^2 - 2\mu \sum_{i=1}^{n} y_i + n \mu^2 \right) d\mu \right) d\lambda \\
&= \frac{1}{2} \int q^{(i)}(\lambda)\, \lambda \left( \sum_{i=1}^{n} y_i^2 - 2 \sum_{i=1}^{n} y_i \int q^{(i)}(\mu)\, \mu\, d\mu + n \int q^{(i)}(\mu)\, \mu^2\, d\mu \right) d\lambda \\
&= \frac{1}{2} a_\lambda^{(i)} b_\lambda^{(i)} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i)}\right)^2 + s_\mu^{2(i)} \right) \right).
\end{aligned} \tag{15.54}
\]
The average likelihood term in eq. (15.51) thus evaluates to
\[
\int q(\vartheta) \ln p(y \mid \vartheta)\, d\vartheta = \frac{n}{2} \left( \psi\left(a_\lambda^{(i)}\right) + \ln b_\lambda^{(i)} \right) - \frac{n}{2} \ln 2\pi - \frac{1}{2} a_\lambda^{(i)} b_\lambda^{(i)} \left( \sum_{i=1}^{n} y_i^2 - 2 m_\mu^{(i)} \sum_{i=1}^{n} y_i + n \left( \left(m_\mu^{(i)}\right)^2 + s_\mu^{2(i)} \right) \right). \tag{15.55}
\]
To evaluate the KL divergence term in eq. (15.51), we first note that with the additivity property of the KL divergence for factorized densities (cf. Chapter 12), we have
\[
\mathrm{KL}(q(\mu)q(\lambda) \,\|\, p(\mu)p(\lambda)) = \mathrm{KL}(q(\mu) \,\|\, p(\mu)) + \mathrm{KL}(q(\lambda) \,\|\, p(\lambda)). \tag{15.56}
\]
In the current example, the variable µ is governed by Gaussian densities for both the prior density p(µ) and the variational densities. More specifically, in the variational inference algorithm, the prior density for µ has parameters m_µ, s_µ², while the variational density q(µ) corresponds to the ith variational density q^{(i)}(µ) with parameters m_µ^{(i)}, s_µ^{2(i)}. With the known form of the KL divergence for univariate Gaussian densities, we thus have
\[
\begin{aligned}
\mathrm{KL}\left( q^{(i)}(\mu) \,\|\, p(\mu) \right) &= \mathrm{KL}\left( N\left(\mu; m_\mu^{(i)}, s_\mu^{2(i)}\right) \,\Big\|\, N\left(\mu; m_\mu, s_\mu^2\right) \right) \\
&= \frac{1}{2} \ln \frac{s_\mu^2}{s_\mu^{2(i)}} + \frac{s_\mu^{2(i)} + \left(m_\mu^{(i)}\right)^2 - 2 m_\mu^{(i)} m_\mu + m_\mu^2}{2 s_\mu^2} - \frac{1}{2}.
\end{aligned} \tag{15.57}
\]
Similarly, the variable λ is governed by Gamma densities for both the prior and the variational densities. Specifically, the prior density of λ has parameters a_λ and b_λ, while the variational density q(λ) corresponds to the ith variational Gamma density over λ with parameters a_λ^{(i)} and b_λ^{(i)}. With the known form of the KL divergence for Gamma densities, we thus have
\[
\begin{aligned}
\mathrm{KL}\left( q^{(i)}(\lambda) \,\|\, p(\lambda) \right) &= \mathrm{KL}\left( G\left(\lambda; a_\lambda^{(i)}, b_\lambda^{(i)}\right) \,\Big\|\, G\left(\lambda; a_\lambda, b_\lambda\right) \right) \\
&= \left( a_\lambda^{(i)} - a_\lambda \right) \psi\left( a_\lambda^{(i)} \right) - \ln \Gamma\left( a_\lambda^{(i)} \right) + \ln \Gamma\left( a_\lambda \right) + a_\lambda \ln \frac{b_\lambda}{b_\lambda^{(i)}} + a_\lambda^{(i)}\, \frac{b_\lambda^{(i)} - b_\lambda}{b_\lambda}.
\end{aligned} \tag{15.58}
\]
The KL divergence term in eq. (15.51) thus evaluates to
\[
\begin{aligned}
\mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta)) &= \frac{1}{2} \ln \frac{s_\mu^2}{s_\mu^{2(i)}} + \frac{s_\mu^{2(i)} + \left(m_\mu^{(i)}\right)^2 - 2 m_\mu^{(i)} m_\mu + m_\mu^2}{2 s_\mu^2} - \frac{1}{2} \\
&\quad + \left( a_\lambda^{(i)} - a_\lambda \right) \psi\left( a_\lambda^{(i)} \right) - \ln \Gamma\left( a_\lambda^{(i)} \right) + \ln \Gamma\left( a_\lambda \right) + a_\lambda \ln \frac{b_\lambda}{b_\lambda^{(i)}} + a_\lambda^{(i)}\, \frac{b_\lambda^{(i)} - b_\lambda}{b_\lambda}.
\end{aligned} \tag{15.59}
\]
Combining (15.55) and (15.59) according to (15.51) then yields (15.34). ∎
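Since Gamma KL expressions such as eq. (15.58) are easy to get wrong, they can be validated against direct numerical integration. A short sketch (parameter values arbitrary; shape-scale parameterization as above):

```python
import numpy as np
from scipy.integrate import quad
from scipy.special import digamma, gammaln
from scipy.stats import gamma

# Check of the Gamma KL divergence of eq. (15.58) against numerical
# integration; a1, b1, a0, b0 are arbitrary shape-scale parameters.
a1, b1, a0, b0 = 3.0, 0.5, 2.0, 1.5
kl_closed = ((a1 - a0) * digamma(a1) - gammaln(a1) + gammaln(a0)
             + a0 * np.log(b0 / b1) + a1 * (b1 - b0) / b0)

def integrand(lam):
    q = gamma.pdf(lam, a1, scale=b1)
    return q * (gamma.logpdf(lam, a1, scale=b1) - gamma.logpdf(lam, a0, scale=b0))

# Integrate from a tiny positive lower bound to avoid the log(0) boundary
kl_numeric, _ = quad(integrand, 1e-12, np.inf)
print(kl_closed, kl_numeric)  # the two values should agree closely
```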
15.3 Fixed-form mean-field variational inference
The central idea of fixed-form mean-field variational inference is to pre-define the parametric form of the factorized variational density
\[
q(\vartheta) = \prod_{i=1}^{k} q(\vartheta_i) \tag{15.60}
\]
at all stages of an iterative algorithm for the maximization of the variational free energy. Because the joint density p(y, ϑ) is defined during the formulation of the probabilistic model of interest, this entails that all densities of the variational free energy
\[
F(q(\vartheta)) = \int q(\vartheta) \ln \frac{p(y, \vartheta)}{q(\vartheta)}\, d\vartheta \tag{15.61}
\]
are defined in parametric form at all times of the procedure. If the integral on the right-hand side of eq. (15.61) can be analytically evaluated (or at least approximated) as a function of the parameters of the variational densities q(ϑ_i), i = 1, ..., k, the variational problem of maximizing a functional with respect to probability density functions is rendered a problem of multivariate optimization, which in turn can be addressed using the standard machinery of nonlinear optimization (Chapter 4). In the following, we will exemplify the fixed-form mean-field variational inference approach using a non-linear Gaussian model with a single mean-field partition, which forms the basis for many models in functional neuroimaging (Friston, 2008).
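The essence of this reduction can be illustrated in miniature before turning to the full nonlinear Gaussian model. In the following sketch, the model (a scalar y = θ² + ε with Gaussian prior and noise) and all numerical values are our own toy assumptions: the variational density is fixed to a Gaussian N(m, exp(2ρ)), the free energy is approximated by quadrature as a function of (m, ρ), and the variational problem is handed to a standard numerical optimizer:

```python
import numpy as np
from scipy.integrate import trapezoid
from scipy.optimize import minimize
from scipy.stats import norm

# Toy fixed-form setup (all choices illustrative): theta ~ N(0, 1),
# y = f(theta) + eps with f(theta) = theta**2 and eps ~ N(0, 0.1).
f = lambda theta: theta ** 2
y_obs, sig = 1.2, np.sqrt(0.1)

# Fixed grid for quadrature over theta
grid = np.linspace(-6.0, 6.0, 4001)
log_joint = norm.logpdf(grid, 0.0, 1.0) + norm.logpdf(y_obs, f(grid), sig)

def neg_free_energy(params):
    """-F as a function of the variational parameters (m, rho)."""
    m, rho = params
    s = np.exp(rho)  # reparameterized standard deviation, always positive
    q = norm.pdf(grid, m, s)
    # F(q) = E_q[ln p(y, theta)] + H[q], with the Gaussian entropy in closed form
    expected_log_joint = trapezoid(q * log_joint, grid)
    entropy = 0.5 * np.log(2.0 * np.pi * np.e * s ** 2)
    return -(expected_log_joint + entropy)

res = minimize(neg_free_energy, x0=np.array([0.5, -1.0]))
m_opt, s_opt = res.x[0], np.exp(res.x[1])
print(f"variational density: N({m_opt:.3f}, {s_opt:.3f}^2), F = {-res.fun:.4f}")
```

Because the posterior of this toy model is bimodal, the Gaussian fixed form necessarily settles on one mode depending on the initialization; the point here is only the mechanics of trading a variational calculus problem for a standard numerical optimization problem.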
Fixed-form variational inference for a non-linear Gaussian model
Probabilistic model
We consider the following hierarchical nonlinear Gaussian model comprising an unobservable random vector x and an observable random vector y
\[
x = \mu_x + \eta \tag{15.62}
\]
\[
y = f(x) + \varepsilon, \tag{15.63}
\]
where x, µ_x, η ∈ ℝ^m and y, ε ∈ ℝ^n, and where ε and η are random vectors with distributions