15 | Variational Inference
15.1 Foundations

Variational inference is a statistical inference framework for probabilistic models that comprise unobservable random variables. Its general starting point is a joint probability density function over observable random variables $y$ and unobservable random variables $\vartheta$,

$p(y, \vartheta) = p(\vartheta)\, p(y \mid \vartheta),$   (15.1)

where $p(\vartheta)$ is usually referred to as the prior density and $p(y \mid \vartheta)$ as the likelihood. Given an observed value of $y$, the first aim of variational inference is to determine the conditional density of $\vartheta$ given the observed value of $y$, referred to as the posterior density. The second aim of variational inference is to evaluate the marginal density of the observed data or, equivalently, its logarithm

$\ln p(y) = \ln \int p(y, \vartheta)\, d\vartheta.$   (15.2)

Eq. (15.2) is commonly referred to as the log marginal likelihood or log model evidence. The log model evidence allows for comparing different models in their plausibility to explain observed data. In the variational inference framework it is not the log model evidence itself that is evaluated, but rather a lower-bound approximation to it. This is due to the fact that if a model comprises many unobservable variables $\vartheta$, the integration on the right-hand side of (15.2) can become analytically burdensome or even intractable. To nevertheless achieve its two aims, variational inference in effect replaces an integration problem with an optimization problem. To this end, variational inference exploits a set of information-theoretic quantities as introduced in Chapter 11 and below. Specifically, the following log model evidence decomposition forms the core of the variational inference approach (Figure 15.1):

$\ln p(y) = F(q(\vartheta)) + \mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y)).$   (15.3)

In eq. (15.3), $q(\vartheta)$ denotes an arbitrary probability density over the unobservable variables which is used as an approximation of the posterior density $p(\vartheta \mid y)$. In the following, $q(\vartheta)$ is referred to as the variational density. In words, (15.3) states that for an arbitrary variational density $q(\vartheta)$, the log model evidence comprises the sum of two information-theoretic quantities: the so-called variational free energy, defined as

$F(q(\vartheta)) := \int q(\vartheta) \ln \frac{p(y, \vartheta)}{q(\vartheta)}\, d\vartheta,$   (15.4)

and the KL divergence between the true posterior density $p(\vartheta \mid y)$ and the variational density $q(\vartheta)$,

$\mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y)) = \int q(\vartheta) \ln \frac{q(\vartheta)}{p(\vartheta \mid y)}\, d\vartheta.$   (15.5)

Based on these definitions, it is straightforward to show the validity of the log model evidence decomposition, as done in the proof below.

Figure 15.1. Visualization of the log model evidence decomposition that lies at the heart of the variational inference approach. The upper vertical bar represents the log model evidence, which is a function of the probabilistic model $p(y, \vartheta)$ and is constant for any observation of $y$. As shown in the main text, the log model evidence can readily be rewritten into the sum of the variational free energy term $F(q(\vartheta))$ and a KL divergence term $\mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y))$ if one introduces an arbitrary variational density over the unobservable variables $\vartheta$. Maximizing the variational free energy hence minimizes the KL divergence between the variational density $q(\vartheta)$ and the true posterior density $p(\vartheta \mid y)$ and renders the variational free energy a better approximation of the log model evidence. Equivalently, minimizing the KL divergence between the variational density $q(\vartheta)$ and the true posterior density $p(\vartheta \mid y)$ maximizes the free energy and also renders it a tighter approximation to the log model evidence $\ln p(y)$.
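Because all quantities in (15.3)-(15.5) are available in closed form for a conjugate Gaussian model, the decomposition can also be checked numerically before turning to the formal proof. The following is a minimal sketch, not part of the original text, assuming a univariate model with prior $\vartheta \sim N(\mu_0, \tau_0^2)$, likelihood $y \mid \vartheta \sim N(\vartheta, \sigma^2)$, and an arbitrary Gaussian variational density $q(\vartheta) = N(m, s^2)$; all numerical values are hypothetical.

```python
# Numerical check of ln p(y) = F(q) + KL(q || p(theta|y)), eq. (15.3),
# in a conjugate Gaussian model where all quantities are available in closed form.
import numpy as np
from scipy.stats import norm

mu0, tau0 = 0.0, 2.0      # prior p(theta) = N(mu0, tau0^2)
sigma     = 1.0           # likelihood p(y|theta) = N(theta, sigma^2)
y         = 1.5           # observed datum (hypothetical)

# exact posterior p(theta|y) = N(mu_n, tau_n2) and log model evidence ln p(y)
tau_n2 = 1.0 / (1.0 / tau0**2 + 1.0 / sigma**2)
mu_n   = tau_n2 * (mu0 / tau0**2 + y / sigma**2)
log_evidence = norm.logpdf(y, mu0, np.sqrt(tau0**2 + sigma**2))

# arbitrary Gaussian variational density q(theta) = N(m, s^2)
m, s = -0.3, 0.8

# variational free energy F(q) = E_q[ln p(y,theta)] + entropy of q, cf. eq. (15.4)
E_log_lik   = norm.logpdf(y, m, sigma) - s**2 / (2 * sigma**2)    # E_q[ln p(y|theta)]
E_log_prior = norm.logpdf(m, mu0, tau0) - s**2 / (2 * tau0**2)    # E_q[ln p(theta)]
entropy_q   = 0.5 * np.log(2 * np.pi * np.e * s**2)
F = E_log_lik + E_log_prior + entropy_q

# KL divergence between the Gaussians q(theta) and p(theta|y), cf. eq. (15.5)
kl = np.log(np.sqrt(tau_n2) / s) + (s**2 + (m - mu_n)**2) / (2 * tau_n2) - 0.5

print(F + kl, log_evidence)   # the two terms sum to ln p(y), as stated in (15.3)
```

For any choice of the variational parameters $m$ and $s$, the printed sum $F + \mathrm{KL}$ matches $\ln p(y)$, while $F$ alone varies with the variational density; this is exactly the bound structure exploited in the remainder of the chapter.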
Proof of (15.3). By definition, we have

$F(q(\vartheta)) = \int q(\vartheta) \ln \frac{p(y, \vartheta)}{q(\vartheta)}\, d\vartheta$
$= \int q(\vartheta) \ln \frac{p(y)\, p(\vartheta \mid y)}{q(\vartheta)}\, d\vartheta$
$= \int q(\vartheta) \ln p(y)\, d\vartheta + \int q(\vartheta) \ln \frac{p(\vartheta \mid y)}{q(\vartheta)}\, d\vartheta$
$= \ln p(y) \int q(\vartheta)\, d\vartheta + \int q(\vartheta) \ln \frac{p(\vartheta \mid y)}{q(\vartheta)}\, d\vartheta$   (15.6)
$= \ln p(y) - \int q(\vartheta) \ln \frac{q(\vartheta)}{p(\vartheta \mid y)}\, d\vartheta$
$= \ln p(y) - \mathrm{KL}(q(\vartheta) \,\|\, p(\vartheta \mid y)),$

from which eq. (15.3) follows immediately. □

The log model evidence decomposition (15.3) can be used to achieve the aims of variational inference as follows. First, the non-negativity of the KL divergence has the consequence that the variational free energy $F(q(\vartheta))$ is always smaller than or equal to the log model evidence, i.e.,

$F(q(\vartheta)) \le \ln p(y).$   (15.7)

This fact can be exploited in the numerical application of variational inference to probabilistic models: because the log model evidence is a fixed quantity which depends only on the choice of $p(y, \vartheta)$ and a specific data realization, manipulating the variational density $q(\vartheta)$ for a given data set in such a manner that the variational free energy increases has two consequences. First, the lower bound to the log model evidence becomes tighter, and the variational free energy a better approximation to the log model evidence. Second, because the left-hand side of eq. (15.3) remains constant, the KL divergence between the true posterior and its variational approximation decreases, which renders the variational density $q(\vartheta)$ an increasingly better approximation to the true posterior density $p(\vartheta \mid y)$. Because the variational free energy is a lower bound to the log model evidence, it is also referred to as the evidence lower bound (ELBO).

The maximization of the variational free energy in terms of a variational density is a very general approach for posterior density and log model evidence approximation. Like the maximum likelihood approach, it serves as a guiding principle rather than a concrete numerical algorithm. Accordingly, algorithms that make use of the variational free energy decomposition of the log model evidence are jointly referred to as variational inference algorithms, but many variants exist. In the following two sections, we will discuss two specific variants and illustrate them with examples. The variants will be referred to as free-form mean-field variational inference and fixed-form mean-field variational inference. Here, the term mean-field refers to a factorization assumption of the variational density over $s$ sets of the unobserved random variables,

$q(\vartheta) = \prod_{i=1}^{s} q(\vartheta_i).$   (15.8)

Such a factorization allows the variational free energy to be optimized independently for the variational densities $q(\vartheta_i)$ in a coordinate-wise fashion for $i = 1, \ldots, s$, a procedure sometimes referred to as coordinate ascent variational inference (CAVI). The free-form and fixed-form variants of mean-field variational inference then differ in their assumptions about the variational densities $q(\vartheta_i)$: the defining feature of the free-form mean-field variational inference approach is that the parametric form of the variational densities is not predetermined, but analytically evaluated based on a central result from variational calculus. As such, the free-form mean-field variational inference approach is useful to emphasize the roots of variational inference in variational calculus, but it is also analytically quite demanding. In a functional neuroimaging context, the free-form mean-field approach thus serves primarily didactic purposes.
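To illustrate how maximizing the variational free energy simultaneously tightens the bound (15.7) and decreases the KL divergence, the hypothetical conjugate Gaussian model used above can serve once more: there the exact posterior is itself Gaussian, so numerically maximizing $F$ over the parameters of a Gaussian $q$ should recover the exact posterior and attain $F = \ln p(y)$. The following minimal sketch, again not part of the original text and resting on assumed values and a standard SciPy optimizer, demonstrates this.

```python
# Sketch: maximizing the variational free energy F(q) over the parameters of a
# Gaussian q(theta) = N(m, s^2) recovers the exact posterior and makes the
# bound F(q) <= ln p(y), eq. (15.7), tight. Conjugate Gaussian toy model only.
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm

mu0, tau0, sigma, y = 0.0, 2.0, 1.0, 1.5   # prior, likelihood, datum (hypothetical)

def neg_free_energy(params):
    m, log_s = params
    s = np.exp(log_s)                                  # enforce s > 0
    E_log_lik   = norm.logpdf(y, m, sigma) - s**2 / (2 * sigma**2)
    E_log_prior = norm.logpdf(m, mu0, tau0) - s**2 / (2 * tau0**2)
    entropy_q   = 0.5 * np.log(2 * np.pi * np.e * s**2)
    return -(E_log_lik + E_log_prior + entropy_q)      # negative F(q), to minimize

res = minimize(neg_free_energy, x0=np.array([0.0, 0.0]))

tau_n2 = 1.0 / (1.0 / tau0**2 + 1.0 / sigma**2)        # exact posterior variance
mu_n   = tau_n2 * (mu0 / tau0**2 + y / sigma**2)       # exact posterior mean
log_evidence = norm.logpdf(y, mu0, np.sqrt(tau0**2 + sigma**2))

print(res.x[0], np.exp(res.x[1])**2)   # ~ (mu_n, tau_n2): q matches the posterior
print(-res.fun, log_evidence)          # maximal F(q) ~ ln p(y): the bound is tight
```

Note that in this toy setting the optimum is exact only because the true posterior lies within the assumed Gaussian family; in general, maximizing the variational free energy yields the best approximation within the chosen family of variational densities.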
The fixed-form mean-field variational inference approach, on the other hand, is characterized by predetermined functional forms of the variational densities and is of high practical relevance in functional neuroimaging. In particular, a fixed-form mean-field variational inference approach that rests on Gaussian variational densities enjoys widespread popularity in functional neuroimaging (under the label variational Bayes Laplace algorithm) and in theoretical neuroscience (under the label free energy principle). In contrast to the free-form mean-field approach, the fixed-form mean-field approach is less analytically demanding and replaces a variational optimization problem with a standard numerical optimization problem. This is achieved by analytically evaluating the variational free energy in terms of the parameters of the variational densities.

15.2 Free-form mean-field variational inference

Free-form mean-field variational inference rests on a factorization of the variational density over sets of unobserved random variables,

$q(\vartheta) = q(\vartheta_s)\, q(\vartheta_{\setminus s}),$   (15.9)

referred to as a mean-field approximation. In (15.9), $\vartheta_{\setminus s}$ denotes all unobserved variables not in the $s$th group. For the factorization (15.9), the variational free energy becomes a function of two arguments, namely $q(\vartheta_s)$ and $q(\vartheta_{\setminus s})$. Due to the complexity of the integrals involved, a simultaneous analytical maximization of the variational free energy with respect to both its arguments is often difficult to achieve, and a coordinate-wise approach, i.e., maximization first with respect to $q(\vartheta_s)$ and second with respect to $q(\vartheta_{\setminus s})$, is preferred. Notably, the assumed factorization over sets of variables corresponds to the assumption that the respective variables form stochastically independent contributions to the multivariate posterior, which, depending on the true form of the generative model, may have weak or strong implications for the validity of the ensuing posterior inference.

The question is thus how to obtain the arguments $q(\vartheta_s)$ and $q(\vartheta_{\setminus s})$ that maximize the variational free energy. It turns out that this challenge corresponds to a well-known problem in statistical physics, which has long been solved in a general fashion using variational calculus (Hinton and Van Camp, 1993). In contrast to ordinary calculus, which deals with the optimization of functions with respect to real numbers, variational calculus deals with the optimization of functions of functions (in this context referred to as functionals) with respect to their function arguments. Using variational calculus, it can be shown that the variational free energy is maximized with respect to the unobserved variable partition $\vartheta_s$ if $q(\vartheta_s)$ is set proportional (i.e., equal up to a scaling factor) to the exponential of the expected log joint probability of $y$ and $\vartheta$ under the variational density over $\vartheta_{\setminus s}$. Formally, this can be written as

$q(\vartheta_s) \propto \exp\left( \int q(\vartheta_{\setminus s}) \ln p(y, \vartheta)\, d\vartheta_{\setminus s} \right).$   (15.10)

The result stated in eq.
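As an illustration of the kind of update that an expression of the form (15.10) produces, the following minimal sketch, not part of the original text, applies free-form coordinate-wise updates to a univariate Gaussian model with unknown mean $\mu$ and precision $\lambda$, assuming the conjugate prior $p(\mu \mid \lambda) = N(\mu_0, (\lambda_0 \lambda)^{-1})$, $p(\lambda) = \mathrm{Gamma}(a_0, b_0)$ and the mean-field factorization $q(\mu, \lambda) = q(\mu)\, q(\lambda)$. Evaluating (15.10) for each factor in turn yields a Gaussian $q(\mu)$ and a Gamma $q(\lambda)$; all hyperparameter values and the simulated data below are hypothetical.

```python
# Sketch of free-form (coordinate ascent) mean-field updates derived from a
# result of the form (15.10): univariate Gaussian data with unknown mean mu
# and precision lam, conjugate Normal-Gamma-type prior, q(mu,lam) = q(mu)q(lam).
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(loc=1.0, scale=0.5, size=50)     # simulated data (hypothetical)
n, ybar = y.size, y.mean()

mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0          # prior hyperparameters (hypothetical)

E_lam = a0 / b0                                 # initial guess for E_q[lam]
for _ in range(20):
    # update q(mu) = N(mu_n, 1/lam_n), proportional to exp(E_{q(lam)}[ln p(y,mu,lam)])
    mu_n  = (lam0 * mu0 + n * ybar) / (lam0 + n)
    lam_n = (lam0 + n) * E_lam
    # update q(lam) = Gamma(a_n, b_n), proportional to exp(E_{q(mu)}[ln p(y,mu,lam)])
    a_n = a0 + (n + 1) / 2
    E_sq_data  = np.sum((y - mu_n) ** 2) + n / lam_n    # sum_i E_q[(y_i - mu)^2]
    E_sq_prior = (mu_n - mu0) ** 2 + 1 / lam_n          # E_q[(mu - mu0)^2]
    b_n = b0 + 0.5 * (E_sq_data + lam0 * E_sq_prior)
    E_lam = a_n / b_n                                   # expected precision under q(lam)

print(mu_n, 1 / lam_n)      # approximate posterior mean and variance of mu
print(a_n / b_n)            # approximate posterior mean of the precision lam
```

Each update uses only expectations under the current estimate of the other factor (here $E_q[\lambda]$ and the mean and variance of $q(\mu)$), which is precisely the coordinate-wise optimization scheme referred to above as CAVI.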