Incorporating Nesterov Momentum into Adam
Timothy Dozat

1 Introduction

When attempting to improve the performance of a deep learning system, there are more or less three approaches one can take. The first is to improve the structure of the model, perhaps adding another layer, switching from simple recurrent units to LSTM cells [4], or, in the realm of NLP, taking advantage of syntactic parses (e.g. as in [13, et seq.]). Another approach is to improve the initialization of the model, guaranteeing that the early-stage gradients have certain beneficial properties [3], building in large amounts of sparsity [6], or taking advantage of principles of linear algebra [15]. The final approach is to try a more powerful learning algorithm, such as including a decaying sum over the previous gradients in the update [12], dividing each parameter update by the L2 norm of the previous updates for that parameter [2], or even foregoing first-order algorithms for more powerful but more computationally costly second-order algorithms [9]. This paper pursues the third option: improving the quality of the final solution by using a faster, more powerful learning algorithm.

2 Related Work

2.1 Momentum-based algorithms

Gradient descent is a simple, well-known, and generally very robust optimization algorithm in which the gradient of the function to be minimized with respect to the parameters ($\nabla_{\theta_{t-1}} f(\theta_{t-1})$) is computed, and a portion $\eta$ of that gradient is subtracted from the parameters:

Algorithm 1 Gradient Descent
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $\theta_t \gets \theta_{t-1} - \eta g_t$

Classical momentum [12] accumulates a decaying sum (with decay constant $\mu$) of the previous gradients into a momentum vector $m$ and uses that in place of the true gradient. This has the advantage of accelerating gradient descent along dimensions where the gradient remains relatively consistent across training steps and slowing it along turbulent dimensions where the gradient oscillates significantly.

Algorithm 2 Classical Momentum
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $m_t \gets \mu m_{t-1} + g_t$
  $\theta_t \gets \theta_{t-1} - \eta m_t$

[14] show that Nesterov's accelerated gradient (NAG) [11], which has a provably better convergence bound than gradient descent, can be rewritten as a kind of improved momentum. If we substitute the definition of $m_t$ in place of the symbol $m_t$ in the parameter update, as in (2),

  $\theta_t \gets \theta_{t-1} - \eta m_t$    (1)

  $\theta_t \gets \theta_{t-1} - \eta \mu m_{t-1} - \eta g_t$    (2)

we can see that the term $\mu m_{t-1}$ does not depend on the current gradient $g_t$, so in principle we can get a superior step direction by applying the momentum vector to the parameters before computing the gradient.

Algorithm 3 Nesterov's accelerated gradient
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1} - \eta \mu m_{t-1})$
  $m_t \gets \mu m_{t-1} + g_t$
  $\theta_t \gets \theta_{t-1} - \eta m_t$

The authors provide empirical evidence that this algorithm is superior to gradient descent, classical momentum, and Hessian-Free [9] algorithms for conventionally difficult optimization objectives.
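To make the contrast between Algorithms 1-3 concrete, the following is a minimal NumPy sketch of the three update rules on a toy ill-conditioned quadratic. The objective, the step size, the momentum constant, and all variable names are illustrative assumptions and are not taken from the paper.

```python
import numpy as np

# Toy objective f(theta) = 0.5 * theta^T A theta, chosen to be ill-conditioned
# so that momentum-style methods visibly outperform plain gradient descent.
A = np.diag([1.0, 50.0])
grad = lambda theta: A @ theta   # gradient of the toy objective

eta, mu, steps = 0.01, 0.9, 200  # assumed hyperparameters for the sketch
theta_gd = theta_cm = theta_nag = np.array([1.0, 1.0])
m_cm = m_nag = np.zeros(2)

for t in range(steps):
    # Algorithm 1: plain gradient descent
    theta_gd = theta_gd - eta * grad(theta_gd)

    # Algorithm 2: classical momentum -- decaying sum of gradients
    m_cm = mu * m_cm + grad(theta_cm)
    theta_cm = theta_cm - eta * m_cm

    # Algorithm 3: NAG -- apply the momentum step before taking the gradient
    g = grad(theta_nag - eta * mu * m_nag)
    m_nag = mu * m_nag + g
    theta_nag = theta_nag - eta * m_nag

# Distance from the minimum (the origin) for each method.
print(np.linalg.norm(theta_gd), np.linalg.norm(theta_cm), np.linalg.norm(theta_nag))
```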
2.2 L2 norm-based algorithms

[2] present adaptive subgradient descent (AdaGrad), which divides $\eta$ at every step by the L2 norm of all previous gradients; this slows down learning along dimensions that have already changed significantly and speeds up learning along dimensions that have only changed slightly, stabilizing the model's representation of common features and allowing it to rapidly "catch up" its representation of rare features.¹

¹Most implementations of this kind of algorithm include an $\epsilon$ parameter to keep the denominator from being too small and resulting in an irrecoverably large step.

Algorithm 4 AdaGrad
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $n_t \gets n_{t-1} + g_t^2$
  $\theta_t \gets \theta_{t-1} - \eta \frac{g_t}{\sqrt{n_t} + \epsilon}$

One notable problem with AdaGrad is that the norm vector $n$ eventually becomes so large that training slows to a halt, preventing the model from reaching the local minimum; [16] go on to motivate RMSProp, an alternative to AdaGrad that replaces the sum in $n_t$ with a decaying mean, parameterized here by $\nu$. This allows the model to continue to learn indefinitely.

Algorithm 5 RMSProp
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $n_t \gets \nu n_{t-1} + (1 - \nu) g_t^2$
  $\theta_t \gets \theta_{t-1} - \eta \frac{g_t}{\sqrt{n_t} + \epsilon}$

2.3 Combination

One might ask whether combining the momentum-based and norm-based methods would provide the advantages of both. In fact, [5] successfully do so with adaptive moment estimation (Adam), combining classical momentum (using a decaying mean instead of a decaying sum) with RMSProp to improve performance on a number of benchmarks. In their algorithm, they include initialization bias correction terms, which offset some of the instability that initializing $m$ and $n$ to 0 can create.

Algorithm 6 Adam
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $m_t \gets \mu m_{t-1} + (1 - \mu) g_t$
  $\hat{m}_t \gets \frac{m_t}{1 - \mu^t}$
  $n_t \gets \nu n_{t-1} + (1 - \nu) g_t^2$
  $\hat{n}_t \gets \frac{n_t}{1 - \nu^t}$
  $\theta_t \gets \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{n}_t} + \epsilon}$

[5] also include an algorithm, AdaMax, that replaces the L2 norm with the L∞ norm, removing the need for $\hat{n}_t$ and replacing the $n_t$ and $\theta_t$ updates with the following:

  $n_t \gets \max(\nu n_{t-1}, |g_t|)$
  $\theta_t \gets \theta_{t-1} - \eta \frac{\hat{m}_t}{n_t + \epsilon}$

We can generalize this to RMSProp as well, using the L∞ norm in the denominator instead of the L2 norm, giving what might be called the MaxaProp algorithm.

3 Methods

3.1 NAG revisited

Adam combines RMSProp with classical momentum. But as [14] show, NAG is in general superior to classical momentum, so how would we go about modifying Adam to use NAG instead? First, we rewrite the NAG algorithm to be more straightforward and efficient to implement, at the cost of some intuitive readability. Momentum is most effective with a warming schedule, so for completeness we parameterize $\mu$ by $t$ as well.

Algorithm 7 NAG rewritten
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $m_t \gets \mu_t m_{t-1} + g_t$
  $\bar{m}_t \gets g_t + \mu_{t+1} m_t$
  $\theta_t \gets \theta_{t-1} - \eta \bar{m}_t$

Here, the step vector $\bar{m}_t$ contains the gradient update for the current timestep, $g_t$, in addition to the momentum vector update for the next timestep, $\mu_{t+1} m_t$, which needs to be applied before taking the gradient at the next timestep. We no longer need to apply the momentum vector for the current timestep because we already applied it in the last update of the parameters, at timestep $t - 1$.
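Since the rewriting folds the next timestep's momentum step into the current update, a small numerical check may help confirm that nothing is lost: with a constant $\mu$ and $m_0 = 0$, the parameters maintained by Algorithm 7 coincide at every iteration with Algorithm 3's parameters after the momentum step has been applied, i.e. $\theta - \eta \mu m$. The toy objective and all names below are assumptions for illustration only.

```python
import numpy as np

# Compare Algorithm 3 (standard NAG) with Algorithm 7 (the rewritten form)
# on a toy quadratic, using a constant momentum schedule mu_t = mu.
A = np.diag([1.0, 50.0])
grad = lambda theta: A @ theta

eta, mu = 0.01, 0.9
theta3 = theta7 = np.array([1.0, 1.0])
m3 = m7 = np.zeros(2)

for t in range(100):
    # Algorithm 3: gradient taken at the "lookahead" point
    g = grad(theta3 - eta * mu * m3)
    m3 = mu * m3 + g
    theta3 = theta3 - eta * m3

    # Algorithm 7: gradient taken at the current parameters; the momentum
    # step for the *next* timestep is folded into the current update
    g = grad(theta7)
    m7 = mu * m7 + g
    theta7 = theta7 - eta * (g + mu * m7)

    # Algorithm 7's parameters track Algorithm 3's post-momentum point.
    assert np.allclose(theta7, theta3 - eta * mu * m3)
```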
3.2 Applying NAG to Adam

Ignoring the initialization bias correction terms for the moment, Adam's update rule can be written in terms of the previous momentum/norm vectors and the current gradient update as in (3):

  $\theta_t \gets \theta_{t-1} - \eta \frac{\mu m_{t-1}}{\sqrt{\nu n_{t-1} + (1 - \nu) g_t^2} + \epsilon} - \eta \frac{(1 - \mu) g_t}{\sqrt{\nu n_{t-1} + (1 - \nu) g_t^2} + \epsilon}$    (3)

In rewritten NAG, we would take the first part of the step and apply it before taking the gradient of the cost function $f$; however, the denominator depends on $g_t$, so we cannot take advantage of the trick used in NAG for this equation. However, $\nu$ is generally chosen to be very large (normally > .9), so the difference between $n_{t-1}$ and $n_t$ will in general be very small. We can then replace $n_t$ with $n_{t-1}$ in the first term without losing too much accuracy:

  $\theta_t \gets \theta_{t-1} - \eta \frac{\mu m_{t-1}}{\sqrt{n_{t-1}} + \epsilon} - \eta \frac{(1 - \mu) g_t}{\sqrt{\nu n_{t-1} + (1 - \nu) g_t^2} + \epsilon}$    (4)

The first term in the expression in (4) no longer depends on $g_t$, meaning that here we can use the Nesterov trick; this gives us the following expressions for $\bar{m}_t$ and $\theta_t$:

  $\bar{m}_t \gets (1 - \mu_t) g_t + \mu_{t+1} m_t$
  $\theta_t \gets \theta_{t-1} - \eta \frac{\bar{m}_t}{\sqrt{n_t} + \epsilon}$

All that is left is to determine how to include the initialization bias correction terms, taking into consideration that $\bar{m}_t$ now combines the gradient of the current timestep with the momentum vector of the next timestep, so the two receive different correction factors. The result is Algorithm 8.

Algorithm 8 Nesterov-accelerated adaptive moment estimation
  $g_t \gets \nabla_{\theta_{t-1}} f(\theta_{t-1})$
  $\hat{g}_t \gets \frac{g_t}{1 - \prod_{i=1}^{t} \mu_i}$
  $m_t \gets \mu_t m_{t-1} + (1 - \mu_t) g_t$
  $\hat{m}_t \gets \frac{m_t}{1 - \prod_{i=1}^{t+1} \mu_i}$
  $n_t \gets \nu n_{t-1} + (1 - \nu) g_t^2$
  $\hat{n}_t \gets \frac{n_t}{1 - \nu^t}$
  $\bar{m}_t \gets (1 - \mu_t) \hat{g}_t + \mu_{t+1} \hat{m}_t$
  $\theta_t \gets \theta_{t-1} - \eta \frac{\bar{m}_t}{\sqrt{\hat{n}_t} + \epsilon}$

4 Experiments

To test this algorithm, we compared the performance of nine algorithms (GD, Momentum, NAG, RMSProp, Adam, Nadam, MaxaProp, AdaMax, and Nadamax) on three benchmarks: word2vec [10], MNIST image classification [7], and LSTM language models [17]. All algorithms used $\nu = .999$ and $\epsilon = 10^{-8}$ as suggested in [5], with a momentum schedule given by $\mu_t = \mu(1 - .5 \times .96^{t/250})$ with $\mu = .99$, similar to the recommendation in [14]. Only the learning rate $\eta$ differed across algorithms and experiments. The algorithms were all coded using Google's TensorFlow [1] API, and the experiments were done using the built-in TensorFlow models, making only small edits to the default settings. All algorithms used initialization bias correction.

4.1 Word2Vec

Word2vec [10] word embeddings were trained using each of the nine algorithms. Approximately 100MB of cleaned text² from Wikipedia was used as the source text, and any word not in the top 50000 words was replaced with UNK. 128-dimensional vectors with a left and right context size of 1 were trained using noise-contrastive estimation with 64 negative samples. Validation was done using the word analogy task; we report the average cosine difference ($\frac{1 - \mathrm{cossim}(x, y)}{2}$) between the analogy vector and the embedding of the correct answer to the analogy.
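For concreteness, here is a short sketch of the Nadam update (Algorithm 8) using the hyperparameters reported above ($\nu = .999$, $\epsilon = 10^{-8}$, $\mu = .99$ with the warming schedule). The toy objective, the learning rate, and the helper names are assumptions made for illustration; this is not the TensorFlow code used in the experiments.

```python
import numpy as np

# Assumed hyperparameters for the sketch; only eta is not given in the paper.
nu, eps, mu, eta = 0.999, 1e-8, 0.99, 0.001

def mu_t(t):
    """Momentum warming schedule mu_t = mu * (1 - .5 * .96**(t/250))."""
    return mu * (1.0 - 0.5 * 0.96 ** (t / 250.0))

def nadam_step(theta, grad_fn, m, n, mu_prod, t):
    """One update of Algorithm 8 (Nesterov-accelerated adaptive moment estimation)."""
    g = grad_fn(theta)
    mu_cur, mu_next = mu_t(t), mu_t(t + 1)
    mu_prod_cur = mu_prod * mu_cur               # prod_{i=1..t} mu_i
    g_hat = g / (1.0 - mu_prod_cur)              # bias-corrected gradient
    m = mu_cur * m + (1.0 - mu_cur) * g
    m_hat = m / (1.0 - mu_prod_cur * mu_next)    # corrected with prod_{i=1..t+1} mu_i
    n = nu * n + (1.0 - nu) * g ** 2
    n_hat = n / (1.0 - nu ** t)
    m_bar = (1.0 - mu_cur) * g_hat + mu_next * m_hat
    theta = theta - eta * m_bar / (np.sqrt(n_hat) + eps)
    return theta, m, n, mu_prod_cur

# Toy usage: minimize a small quadratic with gradient A @ theta.
A = np.diag([1.0, 50.0])
theta, m, n, mu_prod = np.array([1.0, 1.0]), np.zeros(2), np.zeros(2), 1.0
for t in range(1, 1001):
    theta, m, n, mu_prod = nadam_step(theta, lambda x: A @ x, m, n, mu_prod, t)
```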