A Simulated Annealing Based Inexact Oracle for Wasserstein Loss Minimization
Jianbo Ye¹  James Z. Wang¹  Jia Li²

¹College of Information Sciences and Technology, The Pennsylvania State University, University Park, PA. ²Department of Statistics, The Pennsylvania State University, University Park, PA. Correspondence to: Jianbo Ye <[email protected]>.

Proceedings of the 34th International Conference on Machine Learning, Sydney, Australia, PMLR 70, 2017. Copyright 2017 by the author(s).

arXiv:1608.03859v4 [stat.CO] 6 Jun 2017

Abstract

Learning under a Wasserstein loss, a.k.a. Wasserstein loss minimization (WLM), is an emerging research topic for gaining insights from a large set of structured objects. Despite being conceptually simple, WLM problems are computationally challenging because they involve minimizing over functions of quantities (i.e., Wasserstein distances) that themselves require numerical algorithms to compute. In this paper, we introduce a stochastic approach based on simulated annealing for solving WLMs. In particular, we have developed a Gibbs sampler to approximate effectively and efficiently the partial gradients of a sequence of Wasserstein losses. Our new approach has the advantages of numerical stability and readiness for warm starts. These characteristics are valuable for WLM problems that often require multiple levels of iterations in which the oracle for computing the value and gradient of a loss function is embedded. We applied the method to optimal transport with Coulomb cost and the Wasserstein non-negative matrix factorization problem, and made comparisons with the existing method of entropy regularization.

1. Introduction

An oracle is a computational module in an optimization procedure that is applied iteratively to obtain certain characteristics of the function being optimized. Typically, it calculates the value and gradient of a loss function $l(x; y)$. In the vast majority of machine learning models, where those loss functions are decomposable along each dimension (e.g., $L_p$ norm, KL divergence, or hinge loss), $\nabla_x l(\cdot; y)$ or $\nabla_y l(x; \cdot)$ is computed in $O(m)$ time, $m$ being the complexity of the outcome variables $x$ or $y$. This part of the calculation is often negligible compared with the calculation of the full gradient with respect to the model parameters. But this is no longer the case in learning problems based on the Wasserstein distance, due to the intrinsic complexity of the distance. We will call such problems Wasserstein loss minimization (WLM). Examples of WLMs include Wasserstein barycenters (Li & Wang, 2008; Agueh & Carlier, 2011; Cuturi & Doucet, 2014; Benamou et al., 2015; Ye & Li, 2014; Ye et al., 2017b), principal geodesics (Seguy & Cuturi, 2015), nonnegative matrix factorization (Rolet et al., 2016; Sandler & Lindenbaum, 2009), barycentric coordinates (Bonneel et al., 2016), and multi-label classification (Frogner et al., 2015).

The Wasserstein distance is defined as the cost of matching two probability measures, and originates from the literature of optimal transport (OT) (Monge, 1781). It takes into account the cross-term similarity between different support points of the distributions, a level of complexity beyond the usual vector data treatment, i.e., converting a distribution into a vector of frequencies. It has been promoted for comparing sets of vectors (e.g., bag-of-words models) by researchers in computer vision, multimedia and, more recently, natural language processing (Kusner et al., 2015; Ye et al., 2017a). However, its potential as a powerful loss function for machine learning has been underexplored. The major obstacle is a lack of standardized and robust numerical methods for solving WLMs. Even a better empirical understanding of the advantages of the distance is of interest.

By long-standing consensus, solving WLMs is challenging (Cuturi & Doucet, 2014). Unlike the usual optimization in machine learning, where the loss and the (partial) gradient can be calculated in linear time, these quantities are non-smooth and hard to obtain in WLMs, requiring the solution of a costly network transportation problem (a.k.a. OT). The time complexity, $O(m^3 \log m)$, is prohibitively high (Orlin, 1993). In contrast to the $L_p$ or KL counterparts, this step of the calculation elevates from a negligible fraction of the overall learning problem to a dominant portion, preventing the scaling of WLMs to large data. Recently, iterative approximation techniques have been developed to compute the loss and the (partial) gradient at complexity $O(m^2/\varepsilon)$ (Cuturi, 2013; Wang & Banerjee, 2014). However, nontrivial algorithmic efforts are needed to incorporate these methods into WLMs because WLMs often require multi-level loops (Cuturi & Doucet, 2014; Frogner et al., 2015). Specifically, one must re-calculate the loss and its partial gradient through many iterations in order to update other model-dependent parameters.

We are thus motivated to seek a fast inexact oracle that (i) runs at lower time complexity per iteration, and (ii) accommodates warm starts and meaningful early stops. These two properties are equally important for efficiently obtaining adequate approximations to the solutions of a sequence of slowly changing OTs. The second property ensures that subsequent OTs can effectively leverage the solutions of the earlier OTs so that the total computational time is low. Approximation techniques with low complexity per iteration already exist for solving a single OT, but they do not possess the second property. In this paper, we introduce a method that uses a time-inhomogeneous Gibbs sampler as an inexact oracle for Wasserstein losses. The Markov chain Monte Carlo (MCMC) based method naturally satisfies the second property, as reflected by the intuition of physicists that MCMC samples can efficiently "remix from a previous equilibrium."

We propose a new optimization approach based on Simulated Annealing (SA) (Kirkpatrick et al., 1983; Corana et al., 1987) for WLMs where the outcome variables are treated as probability measures. SA is especially suitable for the dual OT problem, where the usual Metropolis sampler can be simplified to a Gibbs sampler. To our knowledge, existing optimization techniques used on WLMs are different from MCMC. In practice, MCMC is known to easily accommodate warm starts, which is particularly useful in the context of WLMs. We name this approach Gibbs-OT for short. The Gibbs-OT algorithm is as simple and efficient as Sinkhorn's algorithm, a widely accepted method to approximately solve OT (Cuturi, 2013). We show that Gibbs-OT enjoys improved numerical stability and several algorithmic characteristics valuable for general WLMs. Through experiments, we demonstrate the effectiveness of Gibbs-OT for solving optimal transport with Coulomb cost (Benamou et al., 2016) and the Wasserstein non-negative matrix factorization (NMF) problem (Sandler & Lindenbaum, 2009; Rolet et al., 2016).

2. Related Work

Recently, several methods have been proposed to overcome the aforementioned difficulties in solving WLMs. Representatives include entropic regularization (Cuturi, 2013; Cuturi & Doucet, 2014; Benamou et al., 2015) and Bregman ADMM (Wang & Banerjee, 2014; Ye et al., 2017b). They augment the OT problem with a strongly convex term such that the regularized objective becomes a smooth function of all its coordinating parameters. Neither Sinkhorn's algorithm nor Bregman ADMM can be readily integrated into a general WLM. Based on the entropic regularization of primal OT, Cuturi & Peyré (2016) recently showed that the Legendre transform of the entropy-regularized Wasserstein loss and its gradient can be computed in closed form, which appear in the first-order condition of some complex WLM problems. Using this technique, the regularized primal problem can be converted to an equivalent Fenchel-type dual problem that has a faster numerical solver in the Euclidean space (Rolet et al., 2016). But this methodology can only be applied to the class of WLM problems whose Fenchel-type dual has closed forms for the objective and full gradient. In contrast, the proposed SA-based approach deals directly with the dual OT problem without assuming any particular mathematical structure of the WLM problem, and hence is more flexible to apply.

More recently, approaches based on solving the dual OT problem have been proposed to calculate and optimize the Wasserstein distance between a single pair of distributions with very large support sets — often as large as the size of an entire machine learning dataset (Montavon et al., 2016; Genevay et al., 2016; Arjovsky et al., 2017). For these methods, scalability is achieved in terms of the support size. Our proposed method has a different focus: calculating and optimizing Wasserstein distances between many pairs all together in WLMs, with each distribution having a moderate support size (e.g., dozens or hundreds). We aim at scalability for scenarios in which a large set of distributions must be handled simultaneously, that is, when the optimization cannot be decoupled across the distributions. In addition, existing methods have no on-the-fly mechanism to control the approximation quality within a limited number of iterations.

3. Preliminaries of Optimal Transport

In this section, we present notations and mathematical background, and set up the problem of interest.

Definition 3.1 (Optimal Transportation, OT). Let $p \in \Delta_{m_1}$, $q \in \Delta_{m_2}$, where $\Delta_m$ is the $m$-dimensional simplex

$$\Delta_m \overset{\text{def.}}{=} \{ q \in \mathbb{R}_+^m : \langle q, \mathbf{1} \rangle = 1 \}.$$

The set of transportation plans between $p$ and $q$ is defined as

$$\Pi(p, q) \overset{\text{def.}}{=} \{ Z \in \mathbb{R}_+^{m_1 \times m_2} : Z \cdot \mathbf{1}_{m_2} = p,\ Z^{\top} \cdot \mathbf{1}_{m_1} = q \}.$$

Let $M \in \mathbb{R}_+^{m_1 \times m_2}$ be the matrix of costs. The optimal transport cost between $p$ and $q$ with respect to $M$ is

$$W(p, q) \overset{\text{def.}}{=} \min_{Z \in \Pi(p, q)} \langle Z, M \rangle. \tag{1}$$
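Problem (1) is a linear program over the transportation polytope $\Pi(p, q)$, so for small supports it can be solved exactly with an off-the-shelf LP solver. The following minimal sketch (not the Gibbs-OT method of this paper) encodes the marginal constraints of Definition 3.1 over the flattened plan and solves them with SciPy's `linprog`; the function name `ot_cost` and the toy inputs are our illustrative choices:

```python
import numpy as np
from scipy.optimize import linprog

def ot_cost(p, q, M):
    """Solve the OT linear program of Eq. (1): min_{Z in Pi(p,q)} <Z, M>."""
    m1, m2 = M.shape
    # Equality constraints on the flattened plan z = vec(Z):
    # the first m1 rows encode Z 1 = p, the next m2 rows encode Z^T 1 = q.
    A_eq = np.zeros((m1 + m2, m1 * m2))
    for i in range(m1):
        A_eq[i, i * m2:(i + 1) * m2] = 1.0      # sum_j Z_ij = p_i
    for j in range(m2):
        A_eq[m1 + j, j::m2] = 1.0               # sum_i Z_ij = q_j
    b_eq = np.concatenate([p, q])
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq,
                  bounds=(0, None), method="highs")
    return res.fun, res.x.reshape(m1, m2)

# Two uniform two-point distributions with a 0/1 ground cost:
p = np.array([0.5, 0.5])
q = np.array([0.5, 0.5])
M = np.array([[0.0, 1.0], [1.0, 0.0]])
cost, Z = ot_cost(p, q, M)   # identity matching is optimal, so cost is 0
```

Exact solvers of this kind carry the $O(m^3 \log m)$-type cost discussed in the introduction, which is precisely what motivates the approximate oracles studied in this paper.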
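For reference, the entropic-regularization baseline of Section 2 replaces (1) with a strongly convex problem solved by Sinkhorn's algorithm (Cuturi, 2013), which alternately rescales the kernel $K = \exp(-M/\lambda)$ so the plan matches the two marginals. A minimal sketch (the function name `sinkhorn`, the regularization value, and the iteration count are our illustrative choices, not the paper's):

```python
import numpy as np

def sinkhorn(p, q, M, reg=0.05, n_iter=1000):
    """Entropic-regularized OT via Sinkhorn scaling (Cuturi, 2013).

    Alternately rescales K = exp(-M/reg) to match the marginals p and q;
    returns the regularized transport plan Z = diag(u) K diag(v).
    """
    K = np.exp(-M / reg)
    u = np.ones_like(p)
    for _ in range(n_iter):
        v = q / (K.T @ u)       # enforce column marginals: Z^T 1 = q
        u = p / (K @ v)         # enforce row marginals:    Z 1 = p
    return u[:, None] * K * v[None, :]

p = np.array([0.5, 0.5])
q = np.array([0.5, 0.5])
M = np.array([[0.0, 1.0], [1.0, 0.0]])
Z = sinkhorn(p, q, M)
approx_cost = (Z * M).sum()    # close to the exact OT cost of 0 for small reg
```

Each iteration costs $O(m_1 m_2)$, matching the $O(m^2/\varepsilon)$-per-target complexity cited in the introduction; Gibbs-OT is proposed as an alternative of comparable simplicity with better warm-start behavior.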