Bregman Duality in Thermodynamic Variational Inference

A. Conjugate Duality

The Bregman divergence associated with a convex $f : \Omega \to \mathbb{R}$ can be written as (Banerjee et al., 2005):

$$D_f[p : q] = f(p) - f(q) - \langle p - q, \nabla f(q) \rangle$$

The family of Bregman divergences includes many familiar quantities, including the KL divergence, which corresponds to the negative entropy generator $f(p) = \int p \log p \, d\omega$. Geometrically, the divergence can be viewed as the difference between $f(p)$ and its linear approximation around $q$. Since $f$ is convex, a first-order estimator will lie below the function, yielding $D_f[p : q] \geq 0$.

For our purposes, we let $f \mapsto \psi$, with $\psi(\beta) = \log Z_\beta$, over the domain of probability distributions indexed by the natural parameters of an exponential family (e.g. (13)):

$$D_\psi[\beta_p : \beta_q] = \psi(\beta_p) - \psi(\beta_q) - \langle \beta_p - \beta_q, \nabla_\beta \psi(\beta_q) \rangle \quad (34)$$

This is a common setting in the field of information geometry (Amari, 2016), which introduces dually flat manifold structures based on the natural parameters $\beta$ and the mean parameters $\eta$.

A.1. KL Divergence as a Bregman Divergence

For an exponential family with log-partition function $\psi(\beta)$ and sufficient statistics $T(\omega)$ over a random variable $\omega$, the Bregman divergence $D_\psi$ corresponds to a KL divergence. Recalling that $\nabla_\beta \psi(\beta) = \eta = \mathbb{E}_{\pi_\beta}[T(\omega)]$ from (16), we simplify the definition (34) to obtain

$$\begin{aligned}
D_\psi[\beta_p : \beta_q] &= \psi(\beta_p) - \psi(\beta_q) - \beta_p \cdot \eta_q + \beta_q \cdot \eta_q \\
&= \psi(\beta_p) - \psi(\beta_q) - \mathbb{E}_q[\beta_p \cdot T(\omega)] + \mathbb{E}_q[\beta_q \cdot T(\omega)] \\
&= \mathbb{E}_q\big[\underbrace{\beta_q \cdot T(\omega) - \psi(\beta_q) + \log \pi_0(\omega)}_{\log q(\omega)}\big] - \mathbb{E}_q\big[\underbrace{\beta_p \cdot T(\omega) - \psi(\beta_p) + \log \pi_0(\omega)}_{\log p(\omega)}\big] \\
&= D_{KL}[q(\omega) \,\|\, p(\omega)] \quad (35)
\end{aligned}$$

where we have added and subtracted terms involving the base measure $\pi_0(\omega)$, and used the definition of our exponential family from (13). The Bregman divergence $D_\psi$ is thus equal to the KL divergence with arguments reversed.

A.2. Dual Divergence

We can leverage convex duality to derive an alternative divergence based on the conjugate function $\psi^*$:

$$\psi^*(\eta) = \sup_\beta \; \beta \cdot \eta - \psi(\beta) \;\implies\; \eta = \nabla_\beta \psi(\beta_\eta), \qquad \psi^*(\eta) = \beta_\eta \cdot \eta - \psi(\beta_\eta) \quad (36)$$

The conjugate measures the maximum distance between the line $\beta \cdot \eta$ and the function $\psi(\beta)$, which occurs at the unique point $\beta_\eta$ where $\eta = \nabla \psi(\beta_\eta)$. This yields a bijective mapping between $\eta$ and $\beta$ for minimal exponential families (Wainwright & Jordan, 2008). Thus, a distribution $p$ may be indexed by either its natural parameters $\beta_p$ or mean parameters $\eta_p$.

Noting that $(\psi^*)^* = \psi(\beta) = \sup_\eta \; \eta \cdot \beta - \psi^*(\eta)$ (Boyd & Vandenberghe, 2004), we can use a similar argument as above to write this correspondence as $\beta = \nabla_\eta \psi^*(\eta)$. We can then write the dual divergence $D_{\psi^*}$ as:

$$\begin{aligned}
D_{\psi^*}[\eta_p : \eta_q] &= \psi^*(\eta_p) - \psi^*(\eta_q) - \langle \eta_p - \eta_q, \nabla_\eta \psi^*(\eta_q) \rangle \\
&= \psi^*(\eta_p) - \psi^*(\eta_q) - \eta_p \cdot \beta_q + \eta_q \cdot \beta_q \\
&= \psi^*(\eta_p) + \psi(\beta_q) - \eta_p \cdot \beta_q \quad (37)
\end{aligned}$$

where we have used (36) to simplify the final two terms. Similarly,

$$\begin{aligned}
D_\psi[\beta_p : \beta_q] &= \psi(\beta_p) - \psi(\beta_q) - \langle \beta_p - \beta_q, \nabla_\beta \psi(\beta_q) \rangle \\
&= \psi(\beta_p) - \psi(\beta_q) - \beta_p \cdot \eta_q + \beta_q \cdot \eta_q \\
&= \psi(\beta_p) + \psi^*(\eta_q) - \beta_p \cdot \eta_q \quad (38)
\end{aligned}$$

Comparing (37) and (38), we see that the divergences are equivalent with the arguments reversed, so that:

$$D_\psi[\beta_p : \beta_q] = D_{\psi^*}[\eta_q : \eta_p] \quad (39)$$

This indicates that the Bregman divergence $D_{\psi^*}$ should also be a KL divergence, but with the same order of arguments. We derive this fact directly in (44), after investigating the form of the conjugate function $\psi^*$.

A.3. Conjugate $\psi^*$ as Negative Entropy

We first treat the case of an exponential family with no base measure $\pi_0(\omega)$, with derivations including a base measure in App. A.4. For a distribution $p$ in an exponential family, indexed by $\beta_p$ or $\eta_p$, we can write $\log p(\omega) = \beta_p \cdot T(\omega) - \psi(\beta_p)$. Then, (36) becomes:

$$\begin{aligned}
\psi^*(\eta_p) &= \beta_p \cdot \eta_p - \psi(\beta_p) \quad (40) \\
&= \beta_p \cdot \mathbb{E}_p[T(\omega)] - \psi(\beta_p) \quad (41) \\
&= \mathbb{E}_p \log p(\omega) \quad (42) \\
&= -H_p(\omega) \quad (43)
\end{aligned}$$

since $\beta_p$ and $\psi(\beta_p)$ are constant with respect to $\omega$. Utilizing $\psi^*(\eta_p) = \mathbb{E}_p \log p(\omega)$ from above, the dual divergence with $q$ becomes:

$$\begin{aligned}
D_{\psi^*}[\eta_p : \eta_q] &= \psi^*(\eta_p) - \psi^*(\eta_q) - \langle \eta_p - \eta_q, \nabla_\eta \psi^*(\eta_q) \rangle \\
&= \mathbb{E}_p \log p(\omega) - \psi^*(\eta_q) - \eta_p \cdot \beta_q + \eta_q \cdot \beta_q \\
&= \mathbb{E}_p \log p(\omega) - \eta_p \cdot \beta_q + \psi(\beta_q) \\
&= \mathbb{E}_p \log p(\omega) - \mathbb{E}_p[\beta_q \cdot T(\omega)] + \psi(\beta_q) \\
&= \mathbb{E}_p \log p(\omega) - \mathbb{E}_p \log q(\omega) \\
&= D_{KL}[p(\omega) \,\|\, q(\omega)] \quad (44)
\end{aligned}$$

Thus, the conjugate function is the negative entropy and induces the KL divergence as its Bregman divergence (Wainwright & Jordan, 2008).

Note that, by ignoring the base distribution over $\omega$, we have instead assumed that $\pi_0(\omega) := u(\omega)$ is uniform over the domain. In the next section, we illustrate that the effect of adding a base distribution is to turn the conjugate function into a KL divergence, with the base $\pi_0(\omega)$ in the second argument. This is consistent with our derivation of negative entropy, since $D_{KL}[p(\omega) \,\|\, u(\omega)] = -H_p(\Omega) + \text{const}$.

A.4. Conjugate $\psi^*$ as a KL Divergence

As noted above, the derivation of the conjugate $\psi^*(\eta)$ in (40)-(43) ignored the possibility of a base distribution in our exponential family. We now see that $\psi^*(\eta)$ takes the form of a KL divergence when considering a base measure $\pi_0(\omega)$:

$$\begin{aligned}
\psi^*(\eta) &= \sup_\beta \; \beta \cdot \eta - \psi(\beta) \quad (45) \\
&= \beta_\eta \cdot \eta - \psi(\beta_\eta) \\
&= \mathbb{E}_{\pi_{\beta_\eta}}[\beta_\eta \cdot T(\omega)] - \psi(\beta_\eta) \\
&= \mathbb{E}_{\pi_{\beta_\eta}}[\beta_\eta \cdot T(\omega)] - \psi(\beta_\eta) \pm \mathbb{E}_{\pi_{\beta_\eta}}[\log \pi_0(\omega)] \\
&= \mathbb{E}_{\pi_{\beta_\eta}}[\log \pi_{\beta_\eta}(\omega) - \log \pi_0(\omega)] \\
&= D_{KL}[\pi_{\beta_\eta}(\omega) \,\|\, \pi_0(\omega)] \quad (46)
\end{aligned}$$

Note that we have added and subtracted a factor of $\mathbb{E}_{\pi_{\beta_\eta}} \log \pi_0(\omega)$ in the fourth line, where our base measure is $\pi_0(\omega) = q(z|x)$ in the case of the TVO. Comparing with the derivations in (41)-(42), we need to include a term of $\mathbb{E}_p \log \pi_0(\omega)$ in moving to an expected log-probability $\mathbb{E}_p \log p(\omega)$, with the extra, subtracted base measure term transforming the negative entropy into a KL divergence. In the TVO setting, this corresponds to

$$\psi^*(\eta) = D_{KL}[\pi_{\beta_\eta}(z|x) \,\|\, q(z|x)]. \quad (47)$$

When including a base distribution, the induced Bregman divergence is still the KL divergence since, as in the derivation of (35), both $\mathbb{E}_p \log p(\omega)$ and $\mathbb{E}_p \log q(\omega)$ will contain terms involving the base distribution $\mathbb{E}_p \log \pi_0(\omega)$.

B. Rényi Divergence Variational Inference

In this section, we show that each intermediate partition function $\log Z_\beta$ corresponds to a scaled version of the Rényi VI objective $\mathcal{L}_\alpha$ (Li & Turner, 2016).

To begin, we recall the definition of Rényi's $\alpha$-divergence:

$$D_\alpha[p \,\|\, q] = \frac{1}{\alpha - 1} \log \int q(\omega)^{1-\alpha} \, p(\omega)^\alpha \, d\omega$$

Note that this involves geometric mixtures similar to (14). Pulling out a factor of $\beta \log p(x)$ to consider normalized distributions over $z \mid x$, we obtain:

$$\begin{aligned}
\psi(\beta) &= \log \int q(z|x)^{1-\beta} \, p_\theta(x, z)^\beta \, dz \\
&= \beta \log p(x) - (1 - \beta) \, D_\beta[p_\theta(z|x) \,\|\, q(z|x)] \\
&= \beta \big( \log p(x) - D_{1-\beta}[q(z|x) \,\|\, p_\theta(z|x)] \big) \\
&:= \beta \, \mathcal{L}_{1-\beta}
\end{aligned}$$

where we have used the skew symmetry property $D_\alpha[p \,\|\, q] = \frac{\alpha}{1-\alpha} D_{1-\alpha}[q \,\|\, p]$ for $0 < \alpha < 1$ (Van Erven & Harremos, 2014). Here, $\mathcal{L}_\alpha$ is the objective of Li & Turner (2016), which is similar to the ELBO but instead subtracts a Rényi divergence of order $\alpha$. Note that $\psi(0) = 0$ and $\psi(1) = \log p_\theta(x)$, as in Li & Turner (2016) and Sec. 3.

C. TVO using Taylor Remainders

Recall that in Sec. 4, we viewed the KL divergence $D_\psi[\beta : \beta_0]$ as the remainder in a first-order Taylor approximation of $\psi(\beta)$ around $\beta_0$. The TVO objectives correspond to the linear term in this approximation, with the gaps in the $\text{TVO}_L(\theta, \phi, x)$ and $\text{TVO}_U(\theta, \phi, x)$ bounds amounting to a sum of KL divergences or Taylor remainders. Thus, the TVO may be viewed as a first-order method.

Yet we may also ask what happens when considering other approximation orders. We proceed to show that thermodynamic integration arises from a zero-order approximation, while the symmetrized KL divergence corresponds to a similar application of the fundamental theorem of calculus in the mean parameter space $\eta = \nabla_\beta \psi(\beta)$. In App. E, we briefly describe how 'higher-order' TVO objectives might be constructed, although these will no longer be guaranteed to provide upper or lower bounds on likelihood.
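As a concrete illustration of the correspondence in (34)-(35), the following sketch (not from the original; it assumes a toy Bernoulli family with $\psi(\beta) = \log(1 + e^\beta)$ and $\eta(\beta) = \sigma(\beta)$) checks numerically that the Bregman divergence of the log-partition function equals the KL divergence with reversed arguments:

```python
import numpy as np

def psi(b):
    # Log-partition function of a Bernoulli in natural parameters
    return np.log1p(np.exp(b))

def eta(b):
    # Mean parameter: the gradient of psi (a sigmoid)
    return 1.0 / (1.0 + np.exp(-b))

def bregman_psi(bp, bq):
    # Eq. (34): D_psi[beta_p : beta_q]
    return psi(bp) - psi(bq) - (bp - bq) * eta(bq)

def kl_bernoulli(p, q):
    # KL[Bern(p) || Bern(q)]
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

bp, bq = 1.3, -0.4
lhs = bregman_psi(bp, bq)             # D_psi[beta_p : beta_q]
rhs = kl_bernoulli(eta(bq), eta(bp))  # KL[q || p], arguments reversed, Eq. (35)
```

The two quantities agree to machine precision, and both are non-negative, consistent with the convexity argument above.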

We will repeatedly utilize the integral form of the Taylor remainder theorem, which characterizes the error in a $k$-th order approximation of $\psi(x)$ around $a$, with $\beta \in [a, x]$.² This identity can be derived using the fundamental theorem of calculus and repeated integration by parts (see, e.g., Kountourogiannis & Loya (2003) and references therein):

$$R_k(x) = \int_a^x \frac{\nabla^{(k+1)}_\beta \psi(\beta)}{k!} (x - \beta)^k \, d\beta \quad (48)$$

² We use the generic variable $x$, not to be confused with data $x$, for notational simplicity.

C.1. Thermodynamic Integration as 0th Order Remainder

Consider a zero-order Taylor approximation of $\psi(1)$ around $a = 0$, which simply uses $\psi(0)$ as an estimator. Applying the remainder theorem, we obtain the identity (6) underlying thermodynamic integration in the TVO:

$$\begin{aligned}
\psi(1) &= \psi(0) + R_0(1) \quad (49) \\
\psi(1) - \psi(0) &= \int_0^1 \nabla_\beta \psi(\beta) \, d\beta \quad (50) \\
\log p_\theta(x) &= \int_0^1 \mathbb{E}_{\pi_\beta}\Big[ \log \frac{p_\theta(x, z)}{q(z|x)} \Big] \, d\beta \quad (51)
\end{aligned}$$

where the last line follows from the definition of $\eta = \nabla_\beta \psi(\beta) = \nabla_\beta \log Z_\beta$ in (16). Note that this integration is symmetric, in that approximating $\psi(0)$ using $\psi(1)$ leads to an equivalent expression after reversing the order of integration.

C.2. KL Divergence as 1st Order Remainder

We can apply a similar approach to first-order Taylor approximations to reinterpret the TVO bound gaps in (9) and (10), although our remainder expressions will no longer be symmetric. We will thus distinguish between estimating $\psi(x)$ around $a \lessgtr x$ using $R_1^\to(x)$ and $R_1^\leftarrow(x)$, respectively, with the arrow indicating the direction of integration.

Estimating $\psi(\beta_k)$ using a first-order approximation around $a = \beta_{k-1}$, as in the TVO lower bound, the remainder exactly matches the definition of the Bregman divergence in (19):

$$\begin{aligned}
R_1^\to(\beta_k) &= \psi(\beta_k) - \underbrace{\big( \psi(\beta_{k-1}) + (\beta_k - \beta_{k-1}) \nabla_\beta \psi(\beta_{k-1}) \big)}_{\text{First-Order Taylor Approx.}} \\
&= \int_{\beta_{k-1}}^{\beta_k} \frac{\nabla^2_\beta \psi(\beta)}{1!} (\beta_k - \beta) \, d\beta \quad (52)
\end{aligned}$$

where (52) corresponds to the Taylor remainder from (48). Recall that this Bregman divergence $D_\psi[\beta_k : \beta_{k-1}]$ corresponds to a KL divergence $D_{KL}^\to[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}]$ and contributes to the gap in $\text{TVO}_L(\theta, \phi, x)$.

Simplifying the Taylor remainder expression, with $\nabla^2_\beta \psi(\beta) = \text{Var}_{\pi_\beta}\big[\log \frac{p_\theta(x,z)}{q(z|x)}\big]$, we obtain an integral representation of the KL divergence:

$$D_{KL}^\to[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}] = \int_{\beta_{k-1}}^{\beta_k} (\beta_k - \beta) \, \text{Var}_{\pi_\beta}\Big[ \log \frac{p_\theta(x, z)}{q(z|x)} \Big] \, d\beta \quad (53)$$

Following similar arguments in the reverse direction, we can obtain an integral form for the TVO upper bound gap $R_1^\leftarrow(\beta_{k-1}) = D_{KL}[\pi_{\beta_k} \,\|\, \pi_{\beta_{k-1}}]$ via the first-order approximation of $\psi(\beta_{k-1})$ around $a = \beta_k$:

$$\begin{aligned}
R_1^\leftarrow(\beta_{k-1}) &= \psi(\beta_{k-1}) - \psi(\beta_k) + (\beta_k - \beta_{k-1}) \nabla_\beta \psi(\beta_k) \\
&= (\beta_k - \beta_{k-1}) \nabla_\beta \psi(\beta_k) - \big( \psi(\beta_k) - \psi(\beta_{k-1}) \big) \\
&= \int_{\beta_k}^{\beta_{k-1}} \frac{\nabla^2_\beta \psi(\beta)}{1!} (\beta_{k-1} - \beta) \, d\beta \quad (54)
\end{aligned}$$

Note that the TVO upper bound (10) arises from the second line, with $R_1^\leftarrow(\beta_{k-1}) \geq 0$ and $(\beta_k - \beta_{k-1}) \nabla_\beta \psi(\beta_k)$ corresponding to a right-Riemann approximation.

Switching the order of integration in (54), we can write the KL divergence as

$$D_{KL}^\leftarrow[\pi_{\beta_k} \,\|\, \pi_{\beta_{k-1}}] = \int_{\beta_{k-1}}^{\beta_k} (\beta - \beta_{k-1}) \, \text{Var}_{\pi_\beta}\Big[ \log \frac{p_\theta(x, z)}{q(z|x)} \Big] \, d\beta \quad (55)$$

While these integral expressions for the KL divergence may not be immediately intuitive, our use of the Taylor remainder theorem unifies their derivation with that of thermodynamic integration. Alternative derivations may also be found in Dabak & Johnson (2002).

C.3. Symmetrized KL Divergence

Combining the expressions for the KL divergence in Eq. (53) and (55) immediately leads to a known result relating the symmetrized KL divergence to the integral of the Fisher information along the geometric path (Amari, 2016; Dabak & Johnson, 2002):

$$D_{KL}^\leftrightarrow[\beta_{k-1}; \beta_k] = (\beta_k - \beta_{k-1}) \int_{\beta_{k-1}}^{\beta_k} \text{Var}_{\pi_\beta}\Big[ \log \frac{p_\theta(x, z)}{q(z|x)} \Big] \, d\beta \quad (56)$$

where we have defined the symmetrized KL divergence as:

$$D_{KL}^\leftrightarrow[\beta_{k-1}; \beta_k] = D_{KL}^\to[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}] + D_{KL}^\leftarrow[\pi_{\beta_k} \,\|\, \pi_{\beta_{k-1}}]$$

Our goal in this section will be to show that (56) arises from similar 'thermodynamic integration' on the graph of the mean parameters $\eta$. Recall that we previously applied the fundamental theorem of calculus to $\psi(\beta) = \log Z_\beta$ to obtain the difference in log-partition functions

$$\psi(\beta_k) - \psi(\beta_{k-1}) = \int_{\beta_{k-1}}^{\beta_k} \nabla_\beta \psi(\beta) \, d\beta$$

We can obtain a similar expression for the mean parameters $\eta = \nabla_\beta \psi(\beta)$ by integrating over the second derivative:

$$\eta_k - \eta_{k-1} = \int_{\beta_{k-1}}^{\beta_k} \nabla^2_\beta \psi(\beta) \, d\beta \quad (57)$$

Recalling that $\nabla^2_\beta \psi(\beta) = \text{Var}_{\pi_\beta}\big[\log \frac{p_\theta(x,z)}{q(z|x)}\big]$, we see that the integrands in (56) and (57) are identical. Integrating with respect to $\beta$, we obtain the 'area of a rectangle' identity for the symmetrized KL divergence (as in (30)):

$$\begin{aligned}
D_{KL}^\leftrightarrow[\beta_{k-1}; \beta_k] &= (\beta_k - \beta_{k-1}) \int_{\beta_{k-1}}^{\beta_k} \text{Var}_{\pi_\beta}\Big[ \log \frac{p_\theta(x, z)}{q(z|x)} \Big] \, d\beta \\
&= (\beta_k - \beta_{k-1}) \int_{\beta_{k-1}}^{\beta_k} \nabla^2_\beta \psi(\beta) \, d\beta \\
&= (\beta_k - \beta_{k-1})(\eta_k - \eta_{k-1}) \quad (58)
\end{aligned}$$

This identity is best understood via Fig. 5 in Sec. 4.4.

To summarize, we have given several equivalent ways of understanding the symmetrized KL divergence. The 'forward' and 'reverse' KL divergences arise as gaps in the TVO left- and right-Riemann approximations (Figure 5), or first-order Taylor remainders as in (53) and (55). Summing these quantities corresponds to the area of a rectangle (58) on the graph of the TVO integrand $\eta_\beta$, or to the integral of a variance term via the Taylor remainder theorem (56) or the fundamental theorem of calculus (57).

Note that the TVO integrand $\eta_\beta = \nabla_\beta \psi(\beta) = \mathbb{E}_{\pi_\beta}\big[\log \frac{p_\theta(x,z)}{q(z|x)}\big]$ will be linear when its derivative, the variance of the log importance weights, is constant within $[\beta_{k-1}, \beta_k]$. The KL divergence is actually symmetric in this case, which we treat in more detail in the next section (App. D). More generally, the curvature of the integrand indicates which direction of the KL divergence has larger magnitude, and Figure 5 reflects our empirical observations that $D_{KL}[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}] > D_{KL}[\pi_{\beta_k} \,\|\, \pi_{\beta_{k-1}}]$.

D. Asymptotic Linear Scheduling Analysis

Grosse et al. (2013) treat a quantity identical to $\text{TVO}_L(\theta, \phi, x)$ in the context of analysing the variance of AIS estimators. Using the Central Limit Theorem, Neal (2001) shows that the variance of an AIS estimator is monotonically related to $\text{TVO}_L(\theta, \phi, x)$ under perfect transitions, or independent, exact samples from each intermediate $\pi_\beta$ (see Grosse et al. (2013), Eq. 3). Note, however, that AIS estimates expectations over chains of MCMC samples rather than the simple reweighting used in the TVO.

In this section, we provide additional perspective on the analysis of Grosse et al. (2013), which considers the asymptotic behavior of the scaled gap in $\text{TVO}_L(\theta, \phi, x)$, namely $K \sum_k D_{KL}[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}]$, as $K \to \infty$.

We begin by restating Theorem 1 of Grosse et al. (2013) for the case of the full TVO objective. We describe the resulting 'coarse-grained' linear binning schedule for choosing $\{\beta_k\}$ in D.1 and provide further analysis in D.2.

Theorem 1 (Grosse et al. (2013)). Suppose $K + 1$ distributions $\{\pi_{\beta_k}\}_{k=0}^K$ are linearly spaced along a path $\mathcal{P}$. Under the assumption of perfect transitions, if the Fisher information matrix $G(\beta)$ is smooth, then as $K \to \infty$:

$$K \sum_{k=1}^K D_{KL}^\to[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}] \;\to\; \frac{1}{2} \int_0^1 \dot\beta(t) \cdot G(\beta(t)) \cdot \dot\beta(t) \, dt = \frac{1}{2} \big( D_{KL}^\to[\pi_0 \,\|\, \pi_K] + D_{KL}^\leftarrow[\pi_K \,\|\, \pi_0] \big) \quad (59)$$

Here, we let $t \in [0, 1]$ parameterize the path $\beta(t) = (1 - t) \cdot \beta_0 + t \cdot \beta_K$, and let $\dot\beta(t)$ denote the derivative of the parameter with respect to $t$. For linear mixing of the natural parameters as above, this is a constant: $\dot\beta(t) = \beta_K - \beta_0$. In the case of the full TVO integrand, $\dot\beta(t) = 1$.

Proof. See Grosse et al. (2013) for a detailed proof, which proceeds by taking the Taylor expansion of $D_{KL}[\pi_{\beta_k} \,\|\, \pi_{\beta_k + \Delta\beta}]$ around each $\beta_k$ for small $\Delta\beta$. In particular, $\Delta\beta = \frac{1}{K}(\beta_K - \beta_0)$ for linearly spaced $\beta_k = (1 - \frac{k}{K}) \cdot \beta_0 + \frac{k}{K} \cdot \beta_K$. We assume w.l.o.g. $\beta_K = 1$ and $\beta_0 = 0$, so that $\Delta\beta = \frac{1}{K}$, as in Grosse et al. (2013) or the TVO.

The zero- and first-order terms vanish, and the second-order term, with $\Delta\beta^2 = \frac{1}{K^2}$, can be written as (see e.g. Kullback (1997), p. 26):

$$K \sum_{k=1}^K D_{KL}[\pi_{\beta_k} \,\|\, \pi_{\beta_k + \Delta\beta}] = K \sum_{k=1}^K \dot\beta \cdot \frac{1}{2K^2} G(\beta_k) \cdot \dot\beta + K \cdot \mathcal{O}(K^{-3}) \quad (60)$$

$$\to \frac{1}{2} \int_0^1 \dot\beta(t) \, G(\beta(t)) \, \dot\beta(t) \, dt \quad (61)$$

where we have absorbed $\Delta\beta = \frac{1}{K}$ into a continuous measure $dt$ as $K \to \infty$. □
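The thermodynamic integration identity (50) and the rectangle identity (58) can be spot-checked numerically. The sketch below is not from the original; it assumes a toy Bernoulli family with $\psi(\beta) = \log(1 + e^\beta)$ and $\eta(\beta) = \sigma(\beta)$:

```python
import numpy as np

# Toy Bernoulli family (assumed for illustration): psi(beta) = log(1 + e^beta),
# eta(beta) = sigmoid(beta) is the gradient of psi.
psi = lambda b: np.log1p(np.exp(b))
eta = lambda b: 1.0 / (1.0 + np.exp(-b))

def kl(bp, bq):
    # KL between Bernoulli(eta(bp)) and Bernoulli(eta(bq))
    p, q = eta(bp), eta(bq)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

b0, b1 = 0.0, 1.0

# Thermodynamic integration, Eq. (50): psi(b1) - psi(b0) = integral of eta,
# here approximated with a midpoint Riemann sum on a fine grid.
grid = np.linspace(b0, b1, 10001)
mids = 0.5 * (grid[1:] + grid[:-1])
ti = np.sum(eta(mids) * np.diff(grid))

# Rectangle identity, Eq. (58): KL-> + KL<- = (b1 - b0) * (eta_1 - eta_0)
sym_kl = kl(b0, b1) + kl(b1, b0)
rect = (b1 - b0) * (eta(b1) - eta(b0))
```

The Riemann sum recovers the log-partition difference, while the symmetrized KL matches the rectangle area exactly, as the identity holds without any approximation.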

We now show that this expression corresponds to the symmetrized KL divergence, as in (Amari, 2016; Dabak & Johnson, 2002). While this was not stated in the theorem of Grosse et al. (2013), it has also been shown by e.g. Huszár (2017). Observe that $G(\beta) = \nabla^2_\beta \psi(\beta) = \text{Var}_{\pi_\beta}[T(x, z)]$ as in (56) and (58). Noting that the chain rule implies $\frac{d}{dt} G(\beta(t)) = \frac{d}{d\beta} G(\beta(t)) \frac{d\beta}{dt}$, we can pull one term of $\frac{d\beta}{dt} = \dot\beta(t) = (\beta_K - \beta_0)$ outside the integral and perform integration by substitution. Ignoring the $1/2$ factor,

$$\begin{aligned}
(\beta_K - \beta_0) \int_0^1 G(\beta(t)) \, \frac{d\beta}{dt} \, dt &= (\beta_K - \beta_0) \int_{\beta_0}^{\beta_K} \nabla^2_\beta \psi(\beta) \, d\beta \\
&= (\beta_K - \beta_0)(\eta_K - \eta_0) \quad (62) \\
&= D_{KL}^\to[\pi_0 \,\|\, \pi_K] + D_{KL}^\leftarrow[\pi_K \,\|\, \pi_0]
\end{aligned}$$

so that, reinstating the $1/2$ factor, the limit in (59)-(61) equals half the symmetrized KL divergence between the endpoint distributions.

D.1. 'Coarse-Grained' Linear Schedule

Grosse et al. (2013) then use this asymptotic condition (62) as $K \to \infty$ to inform the choice of a discrete partition $\mathcal{P} = \{\beta_k\}_{k=0}^K$.

More concretely, consider dividing the interval $[0, 1]$ into $J$ equally-spaced knot points $\{\beta_j\}_{j=0}^J$. We then allocate a total budget of $K = \sum_{j=1}^J K_j$ intermediate distributions across sub-intervals $[\beta_{j-1}, \beta_j]$, with uniform linear spacing of the $K_j$ partitions within each sub-interval.

Using (62), Grosse et al. (2013) assign a cost $F_j = (\beta_j - \beta_{j-1})(\eta_j - \eta_{j-1})$ to each 'coarse-grained' interval $[\beta_{j-1}, \beta_j]$. Minimizing $\sum_j F_j / K_j$ subject to $\sum_j K_j = K$, the allocation rule becomes:

$$K_j \propto \sqrt{(\beta_j - \beta_{j-1})(\eta_j - \eta_{j-1})} \quad (63)$$

We observe that performance when using this method can be sensitive to the number of knot points used, and we found $J = 20$ to perform best in our experiments.

D.2. Additional Perspectives on Grosse et al. (2013)

Geometric Intuition for Theorem 1: To further understand Theorem 1 of Grosse et al. (2013), observe that the TVO integrand will appear linear within any interval $[\beta_{k-1}, \beta_k]$ as $K \to \infty$. For general endpoints $\beta_0$ and $\beta_K$, we let $\Delta\beta = \beta_k - \beta_{k-1} = \frac{\beta_K - \beta_0}{K}$.

Having already visualized the symmetrized KL divergence as the area of a rectangle in Figure 5, we can see that each directed KL divergence, $D_{KL}^\to[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}]$ and $D_{KL}^\leftarrow[\pi_{\beta_k} \,\|\, \pi_{\beta_{k-1}}]$, will approach the area of a triangle as the integrand becomes linear for $K \to \infty$, with area equal to $\frac{1}{2} \Delta\beta \cdot \Delta\eta$. Then, the $D_{KL}$ scaled by $K$ becomes

$$\begin{aligned}
K \sum_{k=1}^K D_{KL}^\to[\pi_{\beta_{k-1}} \,\|\, \pi_{\beta_k}] &\to K \sum_{k=1}^K \frac{1}{2} \, \Delta\beta \, (\eta_k - \eta_{k-1}) \quad (64) \\
&= K \sum_{k=1}^K \frac{1}{2} \, \frac{\beta_K - \beta_0}{K} \, (\eta_k - \eta_{k-1}) \\
&= \frac{1}{2} (\beta_K - \beta_0)(\eta_K - \eta_0)
\end{aligned}$$

where, in the last line, we cancel factors of $K$ and note the cancellation of intermediate $\eta_k$ in the telescoping sum.

Thermodynamic Interpretation: This limiting behavior is also discussed in thermodynamics, where the LHS of (64) and (59) corresponds to the rate of entropy production in transitioning a system from $\pi_0$ to $\pi_1$ along a path defined by $\{\beta_k\}$. The condition that $K \to \infty$ refers to the linear response regime, with (59) related to the thermodynamic divergence (Crooks, 2007).

Exponential and Mixture Geodesics: As in the statement of Theorem 1, we can more generally consider connecting two distributions, indexed by natural parameters $\beta_0$ and $\beta_1$, using a parameter $t \in [0, 1]$. The curve $\beta_t = (1 - t) \cdot \beta_0 + t \cdot \beta_1$ then corresponds to our exponential family (13), and is also referred to as the e-geodesic in information geometry (Amari, 2016).

Similarly, the moment-averaged path of Grosse et al. (2013), which also underlies our scheduling strategy in Sec. 5, can be viewed as a linear mixture in the mean parameter space. The m-geodesic then refers to the curve $\eta_t = (1 - t) \cdot \eta_0 + t \cdot \eta_1$ (Amari, 2016). Note that these mixtures reference different distributions for the same parameter $t$, so that $\eta_{\beta_t} \neq \eta_t$.

Grosse et al. (2013) proceed to show that the expression for the symmetric KL divergence (59) corresponds to the integral of the Fisher information along either the geometric or mixture paths (Theorem 2 of Grosse et al. (2013), Theorem 3.2 of Amari (2016)). The union of the intermediate distributions integrated by these two paths coincides in our one-dimensional exponential family, although this intuition does not appear to translate to higher dimensions.

E. Higher Order TVO

While the convexity of the log-partition function yields the family of Bregman divergences from the remainder in the first-order Taylor approximation, we might also consider higher-order terms to obtain tighter bounds on likelihood or to analyse properties of the TVO integrand $\nabla_\beta \psi(\beta)$. We give an example derivation for a second-order TVO objective, although such objectives are no longer guaranteed to be upper or lower bounds on likelihood.
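The asymptotic statement of Theorem 1 can also be verified numerically. The following sketch is not from the original; it assumes a toy Bernoulli family with $\eta(\beta) = \sigma(\beta)$ and checks that, with $K$ linearly spaced points, the scaled sum of forward KL divergences approaches half the symmetrized KL divergence between the endpoints:

```python
import numpy as np

# Toy Bernoulli family (assumed for illustration): eta(beta) = sigmoid(beta).
eta = lambda b: 1.0 / (1.0 + np.exp(-b))

def kl(bp, bq):
    # KL between Bernoulli(eta(bp)) and Bernoulli(eta(bq))
    p, q = eta(bp), eta(bq)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

b0, bK, K = 0.0, 1.0, 4000
betas = np.linspace(b0, bK, K + 1)

# LHS of Eq. (59): K * sum of forward KL divergences over the fine partition
scaled_gap = K * np.sum(kl(betas[:-1], betas[1:]))

# RHS of Eq. (59): half the symmetrized KL between the endpoint distributions
limit = 0.5 * (kl(b0, bK) + kl(bK, b0))
```

With $K = 4000$ the two quantities agree to well within one percent, consistent with the $\mathcal{O}(1/K)$ convergence rate implied by (60).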

Left-to-Right Expansion: We first consider expanding the approximations in the TVO left-Riemann sum to second order. We denote the resulting objective $\mathcal{L}^{(2)}_\to$, since we move 'left-to-right' in estimating $\psi(\beta_k)$ around $\beta_{k-1}$. We begin by writing the second-order Taylor approximation:

$$\psi(\beta_k) \approx \psi(\beta_{k-1}) + (\beta_k - \beta_{k-1}) \nabla_\beta \psi(\beta_{k-1}) + \frac{1}{2} (\beta_k - \beta_{k-1})^2 \nabla^2_\beta \psi(\beta_{k-1}) \quad (65)$$

While $\text{TVO}_L(\theta, \phi, x)$ consists of the first-order term alone, we can also consider adding the non-negative second-order term to form the objective $\mathcal{L}^{(2)}_\to$. Using successive Taylor approximations of $\psi(\beta_k)$, we obtain similar telescoping cancellations to obtain

$$\log p(x) \approx \mathcal{L}^{(2)}_\to := \sum_{k=1}^K (\beta_k - \beta_{k-1}) \, \eta_{\beta_{k-1}} + \sum_{k=1}^K \frac{1}{2} (\beta_k - \beta_{k-1})^2 \, \text{Var}_{\pi_{\beta_{k-1}}}\Big[ \log \frac{p(x, z)}{q(z|x)} \Big] \quad (66)$$

where $\eta_{\beta_{k-1}} = \mathbb{E}_{\pi_{\beta_{k-1}}}\big[\log \frac{p(x,z)}{q(z|x)}\big]$.

We previously obtained a lower bound on log-likelihood via this construction, with $\log p(x) - \text{TVO}_L(\theta, \phi, x) \geq 0$. However, $\mathcal{L}^{(2)}_\to$ will only provide a lower bound if $\nabla_\beta \psi(\beta)$ is concave, i.e. $\nabla^3_\beta \psi(\beta) \leq 0$. To see this, we write the Taylor remainder (48) as

$$R_2^\to(\beta_k) = \int_{\beta_{k-1}}^{\beta_k} \frac{\nabla^3_\beta \psi(t)}{2!} (\beta_k - t)^2 \, dt \quad (67)$$

with the third derivative equal to

$$\begin{aligned}
\nabla^3_\beta \psi(\beta) &= \mathbb{E}_{\pi_\beta}[T(x, z)^3] - 3 \mathbb{E}_{\pi_\beta}[T(x, z)] \cdot \mathbb{E}_{\pi_\beta}[T(x, z)^2] + 2 \mathbb{E}_{\pi_\beta}[T(x, z)]^3 \\
&= \mathbb{E}_{\pi_\beta}[T(x, z)^3] - \mathbb{E}_{\pi_\beta}[T(x, z)]^3 - 3 \mathbb{E}_{\pi_\beta}[T(x, z)] \cdot \text{Var}_{\pi_\beta}[T(x, z)]
\end{aligned}$$

In addition to indicating whether $\mathcal{L}^{(2)}_\to$ is a lower bound on $\log p_\theta(x)$, testing the concavity of $\nabla_\beta \psi(\beta)$ using $\nabla^3_\beta \psi(\beta) \leq 0$ can also indicate whether a trapezoid approximation to the TVO integral provides a valid lower bound. We can give an identical construction for the reverse direction or for higher-order approximations. We leave a full exploration of these objectives for future work.

F. Experimental Setup

Code for all experiments can be found at https://github.com/vmasrani/tvo_all_in.

Model: Following Burda et al. (2015), we use a variational autoencoder (Kingma & Welling, 2013) with a 50-dimensional stochastic layer, $z \in \mathbb{R}^{50}$:

$$\begin{aligned}
p_\theta(x, z) &= p_\theta(x|z) \, p(z) \\
p(z) &= \mathcal{N}(z \mid 0, I) \\
p_\theta(x|z) &= \text{Bern}(x \mid \text{decoder}_\theta(z)) \\
q_\phi(z|x) &= \mathcal{N}(z; \mu_\phi(x), \sigma^2_\phi(x))
\end{aligned}$$

where the encoder and decoder are each two-layer MLPs with tanh activations and 200 hidden dimensions. The output of the encoder is duplicated and passed through an additional linear layer to parameterize the mean and log standard deviation of a conditionally independent Normal distribution. The output of the decoder is a sigmoid which parameterizes the probabilities of the independent Bernoulli distribution. $\theta$ and $\phi$ refer to the weights of the decoder and encoder, respectively.

Dataset: We use Omniglot (Lake et al., 2013), a dataset of 1623 handwritten characters across 50 alphabets. Each datapoint is a binarized $28 \times 28$ image, i.e. $x \in \{0, 1\}^{784}$, where we follow the common procedure in the literature of sampling each binary-valued observation with expectation equal to the real pixel value (Salakhutdinov & Murray, 2008; Burda et al., 2015). We split the dataset into 24,345 training and 8,070 test examples.

Training Procedure: All models are written in PyTorch and trained on GPUs. For each scheduler, we train for 5000 epochs using the Adam optimizer (Kingma & Ba, 2017) with a learning rate of $10^{-3}$ and minibatch size of 1000. All weights are initialized with PyTorch's default initializer.

G. Implementation Details

While the Legendre transform, mapping between a target value of expected sufficient statistics $\eta = \mathbb{E}_{\pi_\beta}[T(\omega)]$ and the appropriate natural parameters $\beta$, can be a difficult problem in general, we describe how to efficiently implement our 'moments-spacing' schedule in the context of the TVO.

Recall from Sec. 5 that we are interested in finding a discrete partition $\mathcal{P} = \{\beta_k\}_{k=0}^K$ such that:

$$\eta_{\beta_k} = \Big(1 - \frac{k}{K}\Big) \cdot \text{ELBO} + \frac{k}{K} \cdot \text{EUBO} \quad (68)$$

In other words, we seek to find the $\beta_k$ such that $\mathbb{E}_{\pi_{\beta_k}}\big[\log \frac{p_\theta(x,z)}{q_\phi(z|x)}\big] \approx \eta_{\beta_k}$, where the $\eta_{\beta_k}$ are equally spaced between the ELBO and EUBO (see Figure 6).

More concretely, we provide pseudo-code implementing our moments-spacing schedule below. Given a set of $S$ log-importance weights per sample, and a number of intermediate distributions $K$:

Figure 10. Pseudo-code Implementation of Moments Scheduling for TVO

import numpy as np
import torch

# Calculate expected sufficient statistics eta at a given beta (Eq. 12)
def calc_eta(log_iw, beta):
    # 1) Exponentiate/normalize over importance sample dimension
    snis = torch.exp(log_iw * beta -
                     torch.logsumexp(log_iw * beta, dim=1, keepdim=True))
    # 2) Sum over importance samples, then mean over data examples
    return torch.mean(torch.sum(snis * log_iw, dim=1), dim=0)


def binary_search(target, log_iw, start=0, stop=1, threshold=0.1):
    # Bisect on beta: eta is nondecreasing in beta, since
    # d(eta)/d(beta) = Var[log_iw] >= 0
    beta_guess = 0.5 * (start + stop)
    eta_guess = calc_eta(log_iw, beta_guess)
    if eta_guess > target + threshold:
        # Overshot the target: search the lower half
        return binary_search(target, log_iw, start=start, stop=beta_guess)
    elif eta_guess < target - threshold:
        # Undershot the target: search the upper half
        return binary_search(target, log_iw, start=beta_guess, stop=stop)
    else:
        return beta_guess


def moments_spacing_schedule(log_iw, K):
    # 1) Calculate target values for uniform moments spacing
    elbo = calc_eta(log_iw, 0)
    eubo = calc_eta(log_iw, 1)
    targets = [(1 - t) * elbo + t * eubo
               for t in np.linspace(0, 1, K + 1)]

    # 2) Find the beta corresponding to each target (including beta = 0, 1)
    beta_schedule = [0]
    for _k in range(1, K):
        beta_k = binary_search(targets[_k], log_iw, start=0, stop=1)
        beta_schedule.append(beta_k)
    beta_schedule.append(1)

    # 3) Return beta_schedule: used as Riemann approximation points
    #    in the TVO objective
    return beta_schedule
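The schedule above can be exercised end-to-end on synthetic log-importance weights. The following numpy re-sketch is not from the original (the hypothetical weight values are for illustration only), using bisection against the equally spaced moment targets of Eq. (68):

```python
import numpy as np

# Synthetic log-importance weights: (data batch, S samples), hypothetical values
rng = np.random.default_rng(0)
log_iw = rng.normal(-1.0, 1.0, size=(100, 50))

def calc_eta(log_iw, beta):
    # SNIS estimate of E_{pi_beta}[log w]: softmax over the sample dimension,
    # then average over the data batch
    lw = log_iw * beta
    lw = lw - lw.max(axis=1, keepdims=True)          # stabilized softmax
    snis = np.exp(lw) / np.exp(lw).sum(axis=1, keepdims=True)
    return np.mean(np.sum(snis * log_iw, axis=1))

def moments_schedule(log_iw, K, tol=1e-4):
    elbo, eubo = calc_eta(log_iw, 0.0), calc_eta(log_iw, 1.0)
    targets = [(1 - t) * elbo + t * eubo for t in np.linspace(0, 1, K + 1)]
    betas = [0.0]
    for target in targets[1:-1]:
        lo, hi = 0.0, 1.0
        while hi - lo > tol:        # bisection: eta is nondecreasing in beta
            mid = 0.5 * (lo + hi)
            if calc_eta(log_iw, mid) > target:
                hi = mid
            else:
                lo = mid
        betas.append(0.5 * (lo + hi))
    betas.append(1.0)
    return np.array(betas)

schedule = moments_schedule(log_iw, K=5)
```

The returned schedule starts at 0, ends at 1, and is strictly increasing, with interior points chosen so that the corresponding $\eta_{\beta_k}$ are (approximately) equally spaced between the ELBO and EUBO.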

H. Additional Results

In this section, we report wall-clock runtimes and run experiments similar to those in Sec. 8 to evaluate our moments-spacing schedule and reparameterized gradient estimator on the binarized MNIST dataset (Salakhutdinov & Murray, 2008).

Wall-Clock Times: We report wall-clock runtimes for various scheduling methods with $S = 50$ and $K = 5$ in Fig. 11. While TVO methods require slight overhead compared with IWAE, our adaptive moments scheduler does not require significantly more computation than the log-uniform baseline.

Grid Search Comparison: We evaluate our moments schedule with $K = 2$ against a grid search over the choice of a single intermediate $\beta_1$ in Fig. 12. The setup is similar to that of Fig. 1 on Omniglot (see Sec. 8), but here we use reparameterization gradients instead of the original TVO. We train for 1000 epochs using an Adam optimizer with learning rate $10^{-3}$ and batch size 100.

We again find that our moments-spacing schedule arrives at an optimal choice of $\beta_1$, and can even outperform the best static value due to its ability to adaptively update at each epoch. It is interesting to note that the final choice of $\beta_1$, which reflects the shape of the TVO integrand, is nearly identical at $\beta_1 \approx 0.30$ across both MNIST and Omniglot.

Comparison with IWAE: We compare the TVO using our moments scheduling against IWAE and IWAE DREG as in Fig. 9 of the main text. We find that our TVO reparameterized gradient estimator achieves nearly identical model learning performance as IWAE and IWAE DREG, with notably improved posterior inference for all values of $S$.

Evaluating Scheduling Strategies: In Fig. 14-18 below, we reproduce the setting of Fig. 8 to evaluate our scheduling strategies by $K$, for the TVO with both REINFORCE and reparameterized gradient estimators, on Omniglot and MNIST. We also report posterior inference results as measured by test $D_{KL}[q(z|x) \,\|\, p_\theta(z|x)]$. In general, we find comparable performance between our moments schedule and the log-uniform baseline, although our approach performs best with $K = 2$ and does not require grid search. Further, on MNIST with batch size 1000 and low $K$, the log-uniform, linear, and coarse-grained schedules suffer from poor performance due to instability in training, which is avoided by our moments schedule. Training can be stabilized by using smaller batch sizes as in Fig. 12.

Figure 11. Omniglot Runtimes ($S = 50$, $K = 5$, 5k epochs)

Figure 12. MNIST $K = 2$, with reparameterization gradients.

(a) MNIST Test $\log p_\theta(x)$ (b) MNIST Test $D_{KL}[q(z|x) \,\|\, p_\theta(z|x)]$

Figure 13. Model Learning and Inference by $S$ (with $K = 5$)

(a) MNIST Test $\log p_\theta(x)$ (b) MNIST Test $D_{KL}[q(z|x) \,\|\, p_\theta(z|x)]$

Figure 14. TVO with REINFORCE Gradients: Model Learning and Inference by $K$ (with $S = 50$)

(a) MNIST Test $\log p_\theta(x)$ (b) MNIST Test $D_{KL}[q(z|x) \,\|\, p_\theta(z|x)]$

Figure 15. TVO with Reparameterized Gradients: Model Learning and Inference by $K$ (with $S = 50$)

(a) Omniglot Test $\log p_\theta(x)$ (b) Omniglot Test $D_{KL}[q(z|x) \,\|\, p_\theta(z|x)]$

Figure 17. TVO with REINFORCE Gradients: Model Learning and Inference by $K$ (with $S = 50$)

(a) Omniglot Test $\log p_\theta(x)$ (b) Omniglot Test $D_{KL}[q(z|x) \,\|\, p_\theta(z|x)]$

Figure 18. TVO with Reparameterized Gradients: Model Learning and Inference by $K$ (with $S = 50$)

I. Reparameterization Gradients for the TVO Integrand

Recall that the TVO objective involves terms of the form

$$\mathbb{E}_{\pi_\beta}[f(z)] \quad \text{where} \quad \pi_\beta(z|x) = \frac{1}{Z_\beta} \, q_\phi(z|x) \Big( \frac{p_\theta(x, z)}{q_\phi(z|x)} \Big)^{\!\beta} \quad \text{and} \quad f(z) = \log \frac{p_\theta(x, z)}{q_\phi(z|x)} \quad (69)$$

While Masrani et al. (2019) derive a REINFORCE-style gradient estimator for the TVO, we seek to apply the reparameterization trick when possible, and thus differentiate with respect to only the inference network parameters $\phi$. Note that, for $z \sim q_\phi(z|x)$ reparameterizable with $z = z(\epsilon, \phi)$ and $\epsilon \sim p(\epsilon)$, any expectation under $\pi_\beta$ can be written as

$$\mathbb{E}_{\pi_\beta}[f(z)] = \frac{1}{Z_\beta} \, \mathbb{E}_{q_\phi(z|x)}\big[ w^\beta f(z) \big] = \frac{1}{Z_\beta} \, \mathbb{E}_\epsilon\big[ w^\beta f(z) \big] \quad \text{where} \quad w = \frac{p_\theta(x, z)}{q_\phi(z|x)} \quad (70)$$

In differentiating (70), we will frequently encounter terms of the form $\mathbb{E}_{\pi_\beta}\big[ f(z) \frac{d}{d\phi} \log w \big]$ for generic $f(z)$. Noting that the total derivative contains score function partial derivatives, we apply the reparameterization trick to these terms in an approach similar to the 'doubly-reparameterized' estimator of Tucker et al. (2018). The following lemma summarizes these calculations, rewritten using expectations under $\pi_\beta$ as in (70).
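The self-normalized form of (70) can be illustrated on a tractable toy problem. The sketch below is not from the original; it assumes $q(z|x) = \mathcal{N}(0, 1)$ and an unnormalized target proportional to $\mathcal{N}(z; \mu, 1)$, so that $\pi_\beta = \mathcal{N}(\beta\mu, 1)$ in closed form:

```python
import numpy as np

# Toy setup (assumed): q(z|x) = N(0, 1), unnormalized p(x, z) ∝ N(z; mu, 1),
# so pi_beta ∝ q^{1-beta} p^beta = N(beta * mu, 1) and E_{pi_beta}[z] = beta * mu.
rng = np.random.default_rng(0)
mu, beta, S = 1.0, 0.5, 200_000

z = rng.normal(0.0, 1.0, size=S)              # z ~ q(z|x)
log_w = 0.5 * z**2 - 0.5 * (z - mu) ** 2      # log w = log p - log q

lw = beta * log_w
snis = np.exp(lw - lw.max())                  # self-normalized w^beta weights
snis = snis / snis.sum()

est = np.sum(snis * z)                        # SNIS estimate of E_{pi_beta}[z]
exact = beta * mu                             # closed form for this toy
```

With $S = 200{,}000$ samples, the self-normalized estimate recovers $\mathbb{E}_{\pi_\beta}[z] = \beta\mu$ to within Monte Carlo error.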

Lemma 1. Let $f(z): \mathbb{R}^M \to \mathbb{R}$, $\pi_\beta(z|x)$, and $w = \frac{p_\theta(x,z)}{q_\phi(z|x)}$ all depend on $\phi$. When $z \sim q_\phi(z|x)$ is reparameterizable via $z = z(\epsilon, \phi)$, $\epsilon \sim p(\epsilon)$, the following identity holds for expectations under $\pi_\beta$:

$$\mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{d}{d\phi} \log w \Big] = \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( (1 - \beta) f(z) \frac{\partial \log w}{\partial z} - \frac{\partial f(z)}{\partial z} \Big) \Big]. \quad (71)$$

Proof. See Appendix I.3.

Corollary 1.1. For the choice of $f(z) = 1$, we obtain

$$\mathbb{E}_{\pi_\beta}\Big[ \frac{d}{d\phi} \log w \Big] = (1 - \beta) \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big]. \quad (72)$$

The following lemma will allow us to apply reparameterization within the normalization constant.

Lemma 2. Let the same conditions hold as in Lemma 1, with $Z_\beta = \int q_\phi(z|x)^{1-\beta} \, p_\theta(x, z)^\beta \, dz$. Then

$$\frac{d}{d\phi} Z_\beta = \beta (1 - \beta) \, \mathbb{E}_\epsilon\Big[ w^\beta \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big]. \quad (73)$$

Proof. See Appendix I.4.

We now proceed to differentiate the TVO integrand given by (69).

I.1. Reparameterized TVO Gradient Estimator

For generic $f(z): \mathbb{R}^M \to \mathbb{R}$ and reparameterizable $z \sim q_\phi(z|x)$ as above, the gradient with respect to $\phi$ can be written as

$$\frac{d}{d\phi} \mathbb{E}_{\pi_\beta}[f(z)] = \mathbb{E}_{\pi_\beta}\Big[ \frac{d f(z)}{d\phi} - \beta \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big] + \beta (1 - \beta) \, \text{Cov}_{\pi_\beta}\Big( f(z), \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big). \quad (74)$$

The gradient of the TVO integrand is of particular interest. For $f(z) = \log w$ with $w = \frac{p_\theta(x,z)}{q_\phi(z|x)}$, (74) simplifies to

$$\frac{d}{d\phi} \mathbb{E}_{\pi_\beta}[\log w] = (1 - 2\beta) \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] + \beta (1 - \beta) \, \text{Cov}_{\pi_\beta}\Big( \log w, \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big). \quad (75)$$

Proof. We begin by applying the product rule.

$$\begin{aligned}
\frac{d}{d\phi} \mathbb{E}_{\pi_\beta}[f(z)] &= \frac{d}{d\phi} \Big( Z_\beta^{-1} \, \mathbb{E}_\epsilon[w^\beta f(z)] \Big) \quad (76) \\
&= \Big( \frac{d}{d\phi} Z_\beta^{-1} \Big) \mathbb{E}_\epsilon\big[w^\beta f(z)\big] + Z_\beta^{-1} \, \mathbb{E}_\epsilon\Big[ f(z) \frac{d}{d\phi} w^\beta \Big] + Z_\beta^{-1} \, \mathbb{E}_\epsilon\Big[ w^\beta \frac{d}{d\phi} f(z) \Big] \quad (77) \\
&= \Big( -\frac{1}{Z_\beta^2} \frac{d}{d\phi} Z_\beta \Big) \mathbb{E}_\epsilon\big[w^\beta f(z)\big] + Z_\beta^{-1} \, \mathbb{E}_\epsilon\Big[ w^\beta f(z) \, \beta \frac{d}{d\phi} \log w \Big] + Z_\beta^{-1} \, \mathbb{E}_\epsilon\Big[ w^\beta \frac{d}{d\phi} f(z) \Big] \quad (78) \\
&= \underbrace{-\Big( \frac{1}{Z_\beta} \frac{d}{d\phi} Z_\beta \Big) \mathbb{E}_{\pi_\beta}[f(z)]}_{(1)} + \underbrace{\beta \, \mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{d}{d\phi} \log w \Big]}_{(2)} + \mathbb{E}_{\pi_\beta}\Big[ \frac{d}{d\phi} f(z) \Big] \quad (79)
\end{aligned}$$

We proceed to simplify only the first two terms, applying Lemma 2 to (1) and Lemma 1 to (2):

$$\begin{aligned}
(1) + (2) &= -\beta(1-\beta) \, \mathbb{E}_\epsilon\Big[ \frac{w^\beta}{Z_\beta} \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] \, \mathbb{E}_{\pi_\beta}[f(z)] + \beta \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( (1-\beta) f(z) \frac{\partial \log w}{\partial z} - \frac{\partial f(z)}{\partial z} \Big) \Big] \quad (80) \\
&= \beta(1-\beta) \, \mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] - \beta(1-\beta) \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] \, \mathbb{E}_{\pi_\beta}[f(z)] - \beta \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big] \quad (81) \\
&= \beta(1-\beta) \Big( \mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] - \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] \, \mathbb{E}_{\pi_\beta}[f(z)] \Big) - \beta \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big] \quad (82) \\
&= \beta \Big( (1-\beta) \, \text{Cov}_{\pi_\beta}\Big( f(z), \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) - \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big] \Big). \quad (83)
\end{aligned}$$

By plugging (83) back into (79), we arrive at the reparameterized gradient for general $f(z)$ in (74):

$$\begin{aligned}
\frac{d}{d\phi} \mathbb{E}_{\pi_\beta}[f(z)] &= \beta(1-\beta) \, \text{Cov}_{\pi_\beta}\Big( f(z), \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) - \beta \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big] + \mathbb{E}_{\pi_\beta}\Big[ \frac{d}{d\phi} f(z) \Big] \quad (84) \\
&= \mathbb{E}_{\pi_\beta}\Big[ \frac{d f(z)}{d\phi} - \beta \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big] + \beta(1-\beta) \, \text{Cov}_{\pi_\beta}\Big( f(z), \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big). \quad (85)
\end{aligned}$$

Finally, to optimize the TVO integrand, we substitute $f(z) = \log w$ into (85). We then use Corollary 1.1 to apply the reparameterization trick within the total derivative in the first term:

$$\begin{aligned}
\frac{d}{d\phi} \mathbb{E}_{\pi_\beta}[\log w] &= \mathbb{E}_{\pi_\beta}\Big[ \underbrace{\frac{d}{d\phi} \log w}_{\text{Corollary 1.1}} - \, \beta \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] + \beta(1-\beta) \, \text{Cov}_{\pi_\beta}\Big( \log w, \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) \quad (86) \\
&= \mathbb{E}_{\pi_\beta}\Big[ (1-\beta) \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} - \beta \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] + \beta(1-\beta) \, \text{Cov}_{\pi_\beta}\Big( \log w, \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) \quad (87) \\
&= (1 - 2\beta) \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] + \beta(1-\beta) \, \text{Cov}_{\pi_\beta}\Big( \log w, \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) \quad (88)
\end{aligned}$$

This establishes (75) and is the expression that we use to optimize the TVO with reparameterization in the main text. □
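Eq. (75) can be checked deterministically on a toy problem where all expectations under $\pi_\beta$ are available exactly. The following sketch is not from the original; it assumes $q_\phi(z) = \mathcal{N}(z; \phi, 1)$ with $z = \phi + \epsilon$ and an unnormalized target $\gamma(z) = \exp(-z^2/(2s^2))$, so that $\pi_\beta$ is Gaussian in closed form and expectations can be computed by Gauss-Hermite quadrature:

```python
import numpy as np

# Assumed toy: q_phi(z) = N(z; phi, 1), z = phi + eps, target
# gamma(z) = exp(-z^2 / (2 s^2)), so log w = (z - phi)^2/2 - z^2/(2 s^2)
# and pi_beta ∝ q^{1-beta} gamma^beta is Gaussian with precision `prec`.
s, beta, phi = 2.0, 0.3, 0.5
prec = (1 - beta) + beta / s**2

# Exact expectations under pi_beta via Gauss-Hermite quadrature
x, wq = np.polynomial.hermite.hermgauss(60)
wq = wq / np.sqrt(np.pi)

def E(f, p=phi):
    # E_{pi_beta}[f(z, p)], where pi_beta has mean (1 - beta) * p / prec
    zz = (1 - beta) * p / prec + np.sqrt(2.0 / prec) * x
    return np.sum(wq * f(zz, p))

log_w = lambda zz, p: (zz - p) ** 2 / 2 - zz**2 / (2 * s**2)
dlogw_dz = lambda zz, p: (zz - p) - zz / s**2    # and dz/dphi = 1 here

# Eq. (75): d/dphi E_{pi_beta}[log w]
cov = E(lambda zz, p: log_w(zz, p) * dlogw_dz(zz, p)) - E(log_w) * E(dlogw_dz)
grad = (1 - 2 * beta) * E(dlogw_dz) + beta * (1 - beta) * cov

# Ground truth: central finite difference of the closed-form objective,
# differentiating through both pi_beta and the explicit phi in log w
h = 1e-5
fd = (E(log_w, phi + h) - E(log_w, phi - h)) / (2 * h)
```

The estimator matches the finite-difference total derivative to high precision, confirming that both the measure term and the explicit $\phi$-dependence of $\log w$ are accounted for in (75).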

I.2. REPARAM / REINFORCE Equivalence for $\pi_\beta$

It is well known (Tucker et al., 2018) that the reparameterization trick and the REINFORCE estimator are equivalent for expectations under $q_\phi(z|x)$, which allows us to trade high-variance REINFORCE gradients for reparameterization gradients that directly consider derivatives of the function $f(z)$:

$$\mathbb{E}_{q_\phi(z|x)}\Big[ f(z) \frac{\partial}{\partial \phi} \log q_\phi(z|x) \Big] = \mathbb{E}_\epsilon\Big[ \frac{\partial z}{\partial \phi} \frac{\partial f(z)}{\partial z} \Big]. \quad (89)$$

We use this equivalence to show a similar result for expectations under $\pi_\beta$, which we will then use in the proofs of Lemma 1 in I.3 and Lemma 2 in I.4.

Lemma 3. Let the same conditions hold as in Lemma 1. Then

$$\mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial}{\partial \phi} \log q_\phi(z|x) \Big] = \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( \frac{\partial f(z)}{\partial z} + \beta f(z) \frac{\partial \log w}{\partial z} \Big) \Big]. \quad (90)$$

Proof.

$$\begin{aligned}
\mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial}{\partial \phi} \log q_\phi(z|x) \Big] &= \frac{1}{Z_\beta} \, \mathbb{E}_{q_\phi(z|x)}\Big[ w^\beta f(z) \frac{\partial}{\partial \phi} \log q_\phi(z|x) \Big] && \text{Using (70)} \quad (91) \\
&= \frac{1}{Z_\beta} \, \mathbb{E}_\epsilon\Big[ \frac{\partial z}{\partial \phi} \frac{\partial (w^\beta f(z))}{\partial z} \Big] && \text{Using (89)} \quad (92) \\
&= \frac{1}{Z_\beta} \, \mathbb{E}_\epsilon\Big[ \frac{\partial z}{\partial \phi} \Big( w^\beta \frac{\partial f(z)}{\partial z} + f(z) \frac{\partial w^\beta}{\partial z} \Big) \Big] \quad (93) \\
&= \frac{1}{Z_\beta} \, \mathbb{E}_\epsilon\Big[ \frac{\partial z}{\partial \phi} \Big( w^\beta \frac{\partial f(z)}{\partial z} + \beta f(z) \, w^\beta \frac{\partial \log w}{\partial z} \Big) \Big] \quad (94) \\
&= \frac{1}{Z_\beta} \, \mathbb{E}_\epsilon\Big[ w^\beta \frac{\partial z}{\partial \phi} \Big( \frac{\partial f(z)}{\partial z} + \beta f(z) \frac{\partial \log w}{\partial z} \Big) \Big] \quad (95) \\
&= \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( \frac{\partial f(z)}{\partial z} + \beta f(z) \frac{\partial \log w}{\partial z} \Big) \Big] && \text{Using (70)} \quad (96)
\end{aligned}$$

□

I.3. Proof of Lemma 1

$$\mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{d}{d\phi} \log w \Big] = \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( (1 - \beta) f(z) \frac{\partial \log w}{\partial z} - \frac{\partial f(z)}{\partial z} \Big) \Big]. \quad (97)$$

Proof. Using the fact that $\frac{\partial}{\partial \phi} \log w = -\frac{\partial}{\partial \phi} \log q_\phi(z|x)$,

$$\begin{aligned}
\mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{d}{d\phi} \log w \Big] &= \mathbb{E}_{\pi_\beta}\Big[ f(z) \Big( \frac{\partial \log w}{\partial \phi} + \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) \Big] \quad (98) \\
&= \mathbb{E}_{\pi_\beta}\Big[ f(z) \Big( -\frac{\partial \log q_\phi(z|x)}{\partial \phi} + \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big) \Big] \quad (99) \\
&= -\mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial \log q_\phi(z|x)}{\partial \phi} \Big] + \mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] \quad (100) \\
&= -\mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( \frac{\partial f(z)}{\partial z} + \beta f(z) \frac{\partial \log w}{\partial z} \Big) \Big] + \mathbb{E}_{\pi_\beta}\Big[ f(z) \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] && \text{Using Lemma 3} \quad (101) \\
&= \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( -\frac{\partial f(z)}{\partial z} - \beta f(z) \frac{\partial \log w}{\partial z} + f(z) \frac{\partial \log w}{\partial z} \Big) \Big] \quad (102) \\
&= \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( (\beta - 1) \Big( -f(z) \frac{\partial \log w}{\partial z} \Big) - \frac{\partial f(z)}{\partial z} \Big) \Big] \quad (103) \\
&= \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \Big( (1 - \beta) f(z) \frac{\partial \log w}{\partial z} - \frac{\partial f(z)}{\partial z} \Big) \Big]. \quad (104)
\end{aligned}$$

□

I.4. Proof of Lemma 2

$$\frac{d}{d\phi} Z_\beta = \beta (1 - \beta) \, \mathbb{E}_\epsilon\Big[ w^\beta \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big]. \quad (105)$$

Proof. Noting that we can use reparameterization inside the integral, $Z_\beta = \int q_\phi(z|x)^{1-\beta} \, p_\theta(x, z)^\beta \, dz = \mathbb{E}_{q_\phi}[w^\beta] = \mathbb{E}_\epsilon[w^\beta]$, we obtain

$$\frac{d}{d\phi} Z_\beta = \frac{d}{d\phi} \mathbb{E}_\epsilon[w^\beta] \quad (106)$$

$$\begin{aligned}
&= \mathbb{E}_\epsilon\Big[ w^\beta \, \beta \frac{d}{d\phi} \log w \Big] \quad (107) \\
&= Z_\beta \, \mathbb{E}_{\pi_\beta}\Big[ \beta \frac{d}{d\phi} \log w \Big] \quad (108) \\
&= \beta (1 - \beta) \, Z_\beta \, \mathbb{E}_{\pi_\beta}\Big[ \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] && \text{Using Corollary 1.1} \quad (109) \\
&= \beta (1 - \beta) \, \mathbb{E}_\epsilon\Big[ w^\beta \frac{\partial z}{\partial \phi} \frac{\partial \log w}{\partial z} \Big] \quad (110)
\end{aligned}$$

□
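Lemma 2 can likewise be spot-checked deterministically. The sketch below is not from the original; it assumes a Gaussian toy with $q_\phi(z) = \mathcal{N}(z; \phi, 1)$, $z = \phi + \epsilon$, and unnormalized target $\gamma(z) = \exp(-z^2/(2s^2))$, comparing the right-hand side of (105) against a finite-difference derivative of $Z_\beta = \mathbb{E}_\epsilon[w^\beta]$:

```python
import numpy as np

# Assumed toy: q_phi(z) = N(z; phi, 1), z = phi + eps,
# gamma(z) = exp(-z^2 / (2 s^2)), log w = (z - phi)^2/2 - z^2/(2 s^2).
s, beta, phi = 2.0, 0.3, 0.5

# Quadrature nodes for eps ~ N(0, 1)
x, wq = np.polynomial.hermite.hermgauss(80)
wq = wq / np.sqrt(np.pi)
eps = np.sqrt(2.0) * x

def Z(p):
    # Z_beta = E_eps[w^beta] as a function of phi
    z = p + eps
    log_w = (z - p) ** 2 / 2 - z**2 / (2 * s**2)
    return np.sum(wq * np.exp(beta * log_w))

# RHS of Lemma 2: beta (1 - beta) E_eps[w^beta (dz/dphi)(dlogw/dz)]
z = phi + eps
log_w = (z - phi) ** 2 / 2 - z**2 / (2 * s**2)
dlogw_dz = (z - phi) - z / s**2          # and dz/dphi = 1 here
lemma2 = beta * (1 - beta) * np.sum(wq * np.exp(beta * log_w) * dlogw_dz)

# LHS: central finite difference of Z_beta with respect to phi
h = 1e-5
fd = (Z(phi + h) - Z(phi - h)) / (2 * h)
```

The two quantities agree to high precision, reflecting that the score-function term in the total derivative of $\log w$ is exactly absorbed by the $(1 - \beta)$ factor of Corollary 1.1.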