Dynamics of Deep Neural Networks and Neural Tangent Hierarchy

Jiaoyang Huang * 1 Horng-Tzer Yau * 2

*Equal contribution. 1 School of Mathematics, IAS, Princeton, NJ, USA. 2 Mathematics Department, Harvard, Cambridge, MA, USA. Correspondence to: Jiaoyang Huang. Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).

Abstract

The evolution of a deep neural network trained by gradient descent in the overparametrization regime can be described by its neural tangent kernel (NTK) (Jacot et al., 2018; Du et al., 2018b;a; Arora et al., 2019b). It was observed (Arora et al., 2019a) that there is a performance gap between the kernel regression using the limiting NTK and the deep neural networks. We study the dynamic of neural networks of finite width and derive an infinite hierarchy of differential equations, the neural tangent hierarchy (NTH). We prove that the NTH truncated at level p >= 2 approximates the dynamic of the NTK up to arbitrary precision under certain conditions on the neural network width and the data set dimension. The assumptions needed for these approximations become weaker as p increases. Finally, the NTH can be viewed as a higher order extension of the NTK. In particular, the NTH truncated at p = 2 recovers the NTK dynamics.

1. Introduction

Deep neural networks have become popular due to their unprecedented success in a variety of machine learning tasks. Image recognition (LeCun et al., 1998; Krizhevsky et al., 2012; Szegedy et al., 2015), speech recognition (Hinton et al., 2012; Sainath et al., 2013), playing Go (Silver et al., 2016; 2017) and natural language understanding (Collobert et al., 2011; Wu et al., 2016; Devlin et al., 2018) are just a few of the recent achievements. However, one aspect of deep neural networks that is not well understood is training. Training a deep neural network is usually done via a gradient descent based algorithm, and analyzing such training dynamics is challenging. Firstly, as highly nonlinear structures, deep neural networks usually involve a large number of parameters. Secondly, as highly non-convex optimization problems, there is no guarantee that a gradient based algorithm will be able to find the optimal parameters efficiently during the training of neural networks. One question then arises: given such complexities, is it possible to obtain a succinct description of the training dynamics?

In this paper, we focus on the empirical risk minimization problem with the quadratic loss function

\min_\theta L(\theta) = \frac{1}{2n} \sum_{\alpha=1}^{n} (f(x_\alpha, \theta) - y_\alpha)^2,

where \{x_\alpha\}_{\alpha=1}^n are the training inputs, \{y_\alpha\}_{\alpha=1}^n are the labels, and the dependence is modeled by a deep fully-connected feedforward neural network with H hidden layers. The network has d input nodes, and the input vector is given by x \in \mathbb{R}^d. For 1 \le \ell \le H, the \ell-th hidden layer has m neurons. Let x^{(\ell)} be the output of the \ell-th layer, with x^{(0)} = x. Then the feedforward neural network is given by the set of recursive equations:

x^{(\ell)} = \frac{1}{\sqrt{m}} \sigma(W^{(\ell)} x^{(\ell-1)}), \quad \ell = 1, 2, \cdots, H,   (1)

where W^{(1)} \in \mathbb{R}^{m \times d} and W^{(\ell)} \in \mathbb{R}^{m \times m} for 2 \le \ell \le H are the weight matrices, and \sigma is the activation unit, which is applied coordinate-wise to its input. The output of the neural network is

f(x, \theta) = a^\top x^{(H)} \in \mathbb{R},   (2)

where a \in \mathbb{R}^m is the weight vector of the output layer. We denote the vector containing all trainable parameters by \theta = (\mathrm{vec}(W^{(1)}), \mathrm{vec}(W^{(2)}), \cdots, \mathrm{vec}(W^{(H)}), a). We remark that this parametrization is nonstandard because of the 1/\sqrt{m} factors. However, it has already been adopted in several recent works (Jacot et al., 2018; Du et al., 2018b;a; Lee et al., 2019). We note that the predictions and training dynamics of (1) are identical to those of standard networks, up to a scaling factor 1/\sqrt{m} in the learning rate for each parameter.

We initialize the neural network with random Gaussian weights following the Xavier initialization scheme (Glorot & Bengio, 2010). More precisely, we set the initial parameter vector \theta_0 as W^{(\ell)}_{ij} \sim \mathcal{N}(0, \sigma_w^2), a_i \sim \mathcal{N}(0, \sigma_a^2). In this way, for the randomly initialized neural network, the L^2 norms of the outputs of all layers are of order one, i.e., \|x^{(\ell)}\|_2 = O(1) for 0 \le \ell \le H, and f(x, \theta_0) = O(1) with high probability.
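To make the scaling in (1)-(2) concrete, the following is a minimal sketch (not the authors' code) of this parametrization with Gaussian initialization; the width m = 2000, depth H = 4, the tanh activation, and sigma_w = sigma_a = 1 are illustrative assumptions. Running it shows that the layer outputs and f(x, theta_0) indeed stay of order one as m grows.

```python
# Minimal sketch of the network (1)-(2) with the nonstandard 1/sqrt(m) scaling.
import numpy as np

def init_network(d, m, H, sigma_w=1.0, sigma_a=1.0, seed=0):
    rng = np.random.default_rng(seed)
    Ws = [sigma_w * rng.standard_normal((m, d))]              # W^(1) in R^{m x d}
    Ws += [sigma_w * rng.standard_normal((m, m)) for _ in range(H - 1)]
    a = sigma_a * rng.standard_normal(m)                      # output weights
    return Ws, a

def forward(Ws, a, x):
    h, norms = x, []
    for W in Ws:                       # eq. (1): x^(l) = sigma(W^(l) x^(l-1)) / sqrt(m)
        h = np.tanh(W @ h) / np.sqrt(W.shape[0])
        norms.append(np.linalg.norm(h))
    return a @ h, norms                # eq. (2): f(x, theta) = a^T x^(H)

Ws, a = init_network(d=10, m=2000, H=4)
x = np.random.default_rng(1).standard_normal(10)
x /= np.linalg.norm(x)
f_out, norms = forward(Ws, a, x)
print(f_out, norms)                    # the layer norms stay O(1) as m grows
```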

In this paper, we train all layers of the neural network with continuous time gradient descent (gradient flow): for any time t \ge 0,

\partial_t W^{(\ell)}_t = -\partial_{W^{(\ell)}} L(\theta_t), \quad \ell = 1, 2, \cdots, H,
\partial_t a_t = -\partial_a L(\theta_t),   (3)

where \theta_t = (\mathrm{vec}(W^{(1)}_t), \mathrm{vec}(W^{(2)}_t), \cdots, \mathrm{vec}(W^{(H)}_t), a_t). For simplicity of notation, we write \sigma(W^{(\ell)} x^{(\ell-1)}) as \sigma_\ell(x), or simply \sigma_\ell if the context is clear. We write its derivative \mathrm{diag}(\sigma'(W^{(\ell)} x^{(\ell-1)})) as \sigma'_\ell(x) = \sigma^{(1)}_\ell(x), and the r-th derivative \mathrm{diag}(\sigma^{(r)}(W^{(\ell)} x^{(\ell-1)})) as \sigma^{(r)}_\ell(x), or \sigma^{(r)}_\ell, for r \ge 1. In this notation, the \sigma^{(r)}_\ell(x) are diagonal matrices. With these notations, the continuous time gradient descent dynamic (3) is, explicitly,

\partial_t W^{(\ell)}_t = -\partial_{W^{(\ell)}} L(\theta_t)
  = -\frac{1}{n} \sum_{\beta=1}^{n} \left( \sigma'_\ell(x_\beta) \frac{(W^{(\ell+1)}_t)^\top}{\sqrt{m}} \cdots \sigma'_H(x_\beta) \frac{a_t}{\sqrt{m}} \right) \otimes (x^{(\ell-1)}_\beta)^\top \, (f(x_\beta, \theta_t) - y_\beta),   (4)

for \ell = 1, 2, \cdots, H, and

\partial_t a_t = -\partial_a L(\theta_t) = -\frac{1}{n} \sum_{\beta=1}^{n} x^{(H)}_\beta (f(x_\beta, \theta_t) - y_\beta).   (5)

1.1. Neural Tangent Kernel

A recent paper (Jacot et al., 2018) introduced the Neural Tangent Kernel (NTK) and proved that the limiting NTK captures the behavior of fully-connected deep neural networks in the infinite width limit trained by gradient descent:

\partial_t f(x, \theta_t) = \partial_\theta f(x, \theta_t) \partial_t \theta_t = -\partial_\theta f(x, \theta_t) \partial_\theta L(\theta_t)
  = -\frac{1}{n} \sum_{\beta=1}^{n} \partial_\theta f(x, \theta_t) \cdot \partial_\theta f(x_\beta, \theta_t) (f(x_\beta, \theta_t) - y_\beta)
  = -\frac{1}{n} \sum_{\beta=1}^{n} K^{(2)}_t(x, x_\beta)(f(x_\beta, \theta_t) - y_\beta),   (6)

where the NTK K^{(2)}_t(\cdot, \cdot) is given by

K^{(2)}_t(x_\alpha, x_\beta) = \langle \partial_\theta f(x_\alpha, \theta_t), \partial_\theta f(x_\beta, \theta_t) \rangle = \sum_{\ell=1}^{H+1} G^{(\ell)}_t(x_\alpha, x_\beta),   (7)

with, for 1 \le \ell \le H,

G^{(\ell)}_t(x_\alpha, x_\beta) = \langle \partial_{W^{(\ell)}} f(x_\alpha, \theta_t), \partial_{W^{(\ell)}} f(x_\beta, \theta_t) \rangle
  = \left\langle \sigma'_\ell(x_\alpha) \frac{(W^{(\ell+1)}_t)^\top}{\sqrt{m}} \cdots \sigma'_H(x_\alpha) \frac{a_t}{\sqrt{m}}, \;\; \sigma'_\ell(x_\beta) \frac{(W^{(\ell+1)}_t)^\top}{\sqrt{m}} \cdots \sigma'_H(x_\beta) \frac{a_t}{\sqrt{m}} \right\rangle \langle x^{(\ell-1)}_\alpha, x^{(\ell-1)}_\beta \rangle,

and

G^{(H+1)}_t(x_\alpha, x_\beta) = \langle \partial_a f(x_\alpha, \theta_t), \partial_a f(x_\beta, \theta_t) \rangle = \langle x^{(H)}_\alpha, x^{(H)}_\beta \rangle.

The NTK K^{(2)}_t(\cdot, \cdot) varies along training. However, in the infinite width limit the training dynamic is very simple: the NTK does not change along training, K^{(2)}_t(\cdot, \cdot) = K^{(2)}_\infty(\cdot, \cdot), and the network function f(x, \theta_t) follows the linear differential equation (Jacot et al., 2018)

\partial_t f(x, \theta_t) = -\frac{1}{n} \sum_{\beta=1}^{n} K^{(2)}_\infty(x, x_\beta)(f(x_\beta, \theta_t) - y_\beta),   (8)

which becomes analytically tractable. In other words, the training dynamic is equivalent to the kernel regression using the limiting NTK K^{(2)}_\infty(\cdot, \cdot). While the linearization (8) is only exact in the infinite width limit, for a sufficiently wide deep neural network, (8) still provides a good approximation of the learning dynamic of the corresponding deep neural network (Du et al., 2018b;a; Lee et al., 2019). As a consequence, it was proven in (Du et al., 2018b;a) that, for a fully-connected wide neural network with m \gtrsim n^4, under certain assumptions on the data set, the gradient descent converges to zero training loss at a linear rate. Although highly overparametrized neural networks are equivalent to the kernel regression, it is possible to show that the class of finite width neural networks is more expressive than the limiting NTK. It has been shown by explicit constructions in (Ghorbani et al., 2019; Yehudai & Shamir, 2019; Allen-Zhu & Li, 2019) that there are simple functions that can be efficiently learnt by finite width neural networks, but not by the kernel regression using the limiting NTK.
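The kernel (7) can also be evaluated directly by automatic differentiation, as an inner product of parameter gradients. The sketch below (illustrative, not the authors' implementation) does this with JAX for the parametrization (1)-(2); the tanh activation and the specific sizes are assumptions. At finite width this empirical kernel drifts during training, which is exactly the effect quantified in the rest of the paper.

```python
# Empirical NTK K^(2)(x1, x2) = <grad_theta f(x1), grad_theta f(x2)>, eq. (7), via JAX.
import jax
import jax.numpy as jnp

def init_params(key, d, m, H, sigma_w=1.0, sigma_a=1.0):
    keys = jax.random.split(key, H + 1)
    Ws = [sigma_w * jax.random.normal(keys[0], (m, d))]
    Ws += [sigma_w * jax.random.normal(keys[l], (m, m)) for l in range(1, H)]
    a = sigma_a * jax.random.normal(keys[H], (m,))
    return Ws, a

def f(params, x):
    Ws, a = params
    h = x
    for W in Ws:                      # eq. (1)
        h = jnp.tanh(W @ h) / jnp.sqrt(W.shape[0])
    return a @ h                      # eq. (2)

def ntk(params, x1, x2):
    g1 = jax.grad(f)(params, x1)      # gradient with respect to all parameters
    g2 = jax.grad(f)(params, x2)
    leaves1 = jax.tree_util.tree_leaves(g1)
    leaves2 = jax.tree_util.tree_leaves(g2)
    return sum(jnp.vdot(u, v) for u, v in zip(leaves1, leaves2))

key = jax.random.PRNGKey(0)
params = init_params(key, d=5, m=1024, H=3)
x1, x2 = jax.random.normal(key, (2, 5))
x1, x2 = x1 / jnp.linalg.norm(x1), x2 / jnp.linalg.norm(x2)
print(ntk(params, x1, x2))
```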

1.2. Contribution

There is a performance gap between the kernel regression using the limiting NTK and the deep neural networks. It was observed in (Arora et al., 2019a) that convolutional neural networks outperform their corresponding limiting NTK by 5%-6%. This performance gap is likely to originate from the change of the NTK along training due to the finite width effect, and this change of the NTK along training has its benefits for generalization.

In the current paper, we study the dynamic of the NTK for finite width deep fully-connected neural networks. Here we summarize our main contributions:

• We show that the gradient descent dynamic is captured by an infinite hierarchy of ordinary differential equations, the neural tangent hierarchy (NTH). Similar recursive differential equations were also obtained by Dyer and Gur-Ari (Dyer & Gur-Ari, 2019). Different from the limiting NTK (7), which depends only on the neural network architecture, the NTH is data dependent and capable of learning data-dependent features.

• We derive a priori estimates of the higher order kernels involved in the NTH. Using these a priori estimates as input, we confirm a numerical observation in (Lee et al., 2019) that the NTK varies at a rate of order O(1/m). As a corollary, this implies that for a fully-connected wide neural network with m \gtrsim n^3, the gradient descent converges to zero training loss at a linear rate, which improves the results in (Du et al., 2018a).

• The NTH is just an infinite sequence of relationships. Without truncation, it cannot be used to determine the dynamic of the NTK. Using the a priori estimates of the higher order kernels as input, we construct a truncated hierarchy of ordinary differential equations, the truncated NTH. We show that this system of truncated equations approximates the dynamic of the NTK up to a certain time and up to arbitrary precision. This description makes it possible to directly study the change of the NTK for deep neural networks.

1.3. Notations

In the paper, we fix a large constant p^* > 0, which appears in Assumptions 2.1 and 2.2. We use c, C to represent universal constants, which might be different from line to line. We write a = O(b) or a \lesssim b if there exists some large universal constant C such that |a| \le Cb. We write a \gtrsim b if there exists some small universal constant c > 0 such that a \ge cb. We write a \asymp b if there exist universal constants c, C such that cb \le |a| \le Cb. We reserve n for the number of input samples and m for the width of the neural network. For practical neural networks, we always have m \lesssim \mathrm{poly}(n) and n \lesssim \mathrm{poly}(m). We denote the set of input samples as \mathcal{X} = \{x_1, x_2, \cdots, x_n\}. For simplicity of notation, we write the output of the neural network as f_\beta(t) = f(x_\beta, \theta_t). We denote the vector L^2 norm as \|\cdot\|_2, the vector or function L^\infty norm as \|\cdot\|_\infty, the matrix spectral norm as \|\cdot\|_{2\to 2}, and the matrix Frobenius norm as \|\cdot\|_F. We say that an event holds with high probability if it holds with probability at least 1 - e^{-m^c} for some c > 0; then the intersection of \mathrm{poly}(n, m) many high probability events is still a high probability event, provided m is large enough. In the paper, we treat c_r, C_r in Assumptions 2.1 and 2.2, and the depth H, as constants, and we will not keep track of them.

1.4. Related Work

In this section, we survey an incomplete list of previous works on the optimization aspect of deep neural networks.

Because of the highly non-convex nature of deep neural networks, gradient based algorithms can potentially get stuck near a critical point, i.e., a saddle point or a local minimum. So one important question for deep neural networks is: what does the loss landscape look like? One promising candidate for loss landscapes is the class of functions that satisfy: (i) all local minima are global minima and (ii) there exists a negative curvature at every saddle point. A line of recent results shows that, in many optimization problems of interest (Ge et al., 2015; 2016; Sun et al., 2018; 2016; Bhojanapalli et al., 2016; Park et al., 2016), loss landscapes are in this class. For this function class, (perturbed) gradient descent (Jin et al., 2017; Ge et al., 2015; Lee et al., 2016) can find a global minimum. However, even for a three-layer linear network, there exists a saddle point that does not have a negative curvature (Kawaguchi, 2016), so it is unclear whether this geometry-based approach can be used to obtain global convergence guarantees for first-order methods. Another approach is to show that practical deep neural networks allow some additional structure or assumption that makes non-convex optimization tractable. Under certain simplifying assumptions, it has been proven recently that there are novel loss landscape structures in deep neural networks, which may play a role in making the optimization tractable (Dauphin et al., 2014; Choromanska et al., 2015; Kawaguchi, 2016; Liang et al., 2018; Kawaguchi & Kaelbling, 2019).

Recently, it was proved in a series of papers that, if the size of a neural network is significantly larger than the size of the dataset, the (stochastic) gradient descent algorithm can find optimal parameters (Li & Liang, 2018; Du et al., 2018b; Song & Yang, 2019; Du et al., 2018a; Allen-Zhu et al., 2018b; Zou et al., 2018; Zou & Gu, 2019). In this overparametrization regime, a fully-trained deep neural network is indeed equivalent to the kernel regression predictor using the limiting NTK (8). As a consequence, the gradient descent achieves zero training loss for a deep overparameterized neural network. Under further assumptions, it can be shown that the trained networks generalize (Arora et al., 2019b; Allen-Zhu et al., 2018a; Cao & Gu, 2019). Unfortunately, there is a significant gap between the overparametrized neural networks which are provably trainable and the neural networks used in common practice. Typically, deep neural networks used in practical applications are trainable, and yet much smaller than what the previous theories require to ensure trainability. In (Kawaguchi & Huang, 2019), it is proven that gradient descent can find a global minimum for certain deep neural networks of sizes commonly encountered in practice.

Training dynamics of neural networks in the mean field setting have been studied in (Mei et al., 2019; Song et al., 2018; Araujo et al., 2019; Nguyen, 2019; Sirignano & Spiliopoulos, 2019; Chizat & Bach, 2018). Their mean field analysis describes the distributional dynamics of neural network parameters via certain nonlinear partial differential equations, in the asymptotic regime of large network sizes and large numbers of stochastic gradient descent training iterations. However, their analysis is restricted to neural networks in the mean-field framework with a normalization factor 1/m, different from our 1/\sqrt{m}, which is commonly used in modern networks (Glorot & Bengio, 2010).

2. Main results

Assumption 2.1. The activation function \sigma is smooth, and for any 1 \le r \le 2p^* + 1 there exists a constant C_r > 0 such that the r-th derivative of \sigma satisfies \|\sigma^{(r)}(x)\|_\infty \le C_r.

Assumption 2.1 is satisfied by common activation units such as the sigmoid and the hyperbolic tangent. Moreover, the softplus activation, which is defined as \sigma_a(x) = \ln(1 + \exp(ax))/a, satisfies Assumption 2.1 for any hyperparameter a \in \mathbb{R}_{>0}. The softplus activation can approximate the ReLU activation to any desired accuracy,

\sigma_a(x) \to \mathrm{relu}(x) \quad \text{as } a \to \infty,

where relu represents the ReLU activation.

Assumption 2.2. There exists a small constant c > 0 such that the training inputs satisfy c < \|x_\alpha\|_2 \le c^{-1}. For any 1 \le r \le 2p^* + 1, there exists a constant c_r > 0 such that for any distinct indices 1 \le \alpha_1, \alpha_2, \cdots, \alpha_r \le n, the smallest singular value of the data matrix [x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}] is at least c_r.

For more general input data, we can always normalize them such that c < \|x_\alpha\|_2 \le c^{-1}. Under this normalization, for the randomly initialized deep neural network, it holds that \|x^{(\ell)}_\alpha\|_2 = O(1) for all 1 \le \ell \le H, where the implicit constants depend on \ell. The second part of Assumption 2.2 requires that any small number of input data x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r} are linearly independent.

Theorem 2.3. Under Assumptions 2.1 and 2.2, there exists an infinite family of operators K^{(r)}_t : \mathcal{X}^r \mapsto \mathbb{R} for r \ge 2 such that the continuous time gradient descent dynamic is given by an infinite hierarchy of ordinary differential equations, i.e., the NTH,

\partial_t (f_\alpha(t) - y_\alpha) = -\frac{1}{n} \sum_{\beta=1}^{n} K^{(2)}_t(x_\alpha, x_\beta)(f_\beta(t) - y_\beta),   (9)

and for any r \ge 2,

\partial_t K^{(r)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) = -\frac{1}{n} \sum_{\beta=1}^{n} K^{(r+1)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)(f_\beta(t) - y_\beta).   (10)

There exists a deterministic family (independent of m) of operators K^{(r)} : \mathcal{X}^r \mapsto \mathbb{R} for 2 \le r \le p^* + 1, with K^{(r)} = 0 if r is odd, such that with high probability with respect to the random initialization there exist some constants C, C' > 0 such that

\left\| K^{(r)}_0 - \frac{K^{(r)}}{m^{r/2-1}} \right\|_\infty \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}},   (11)

and for 0 \le t \le m^{p^*/(2(p^*+1))}/(\ln m)^{C'},

\|K^{(r)}_t\|_\infty \lesssim \frac{(\ln m)^C}{m^{r/2-1}}.   (12)

It was proven in (Du et al., 2018a; Lee et al., 2019) that the change of the NTK for a wide deep neural network is upper bounded by O(1/\sqrt{m}). However, the numerical experiments in (Lee et al., 2019) indicate that the change of the NTK is closer to O(1/m). As a corollary of Theorem 2.3, we confirm this numerical observation: the NTK varies at a rate of order O(1/m).

Corollary 2.4. Under Assumptions 2.1 and 2.2, the NTK K^{(2)}_t(\cdot, \cdot) varies at a rate of order O(1/m): with high probability with respect to the random initialization, there exist some constants C, C' > 0 such that for 0 \le t \le m^{p^*/(2(p^*+1))}/(\ln m)^{C'}, it holds

\|\partial_t K^{(2)}_t\|_\infty \lesssim \frac{(1 + t)(\ln m)^C}{m}.

As another corollary of Theorem 2.3, for a fully-connected wide neural network with m \gtrsim n^3, the gradient descent converges to zero training loss at a linear rate.

Corollary 2.5. Under Assumptions 2.1 and 2.2, we further assume that there exists \lambda > 0 (which might depend on n) such that

\lambda_{\min}\left( [K^{(2)}_0(x_\alpha, x_\beta)]_{1 \le \alpha, \beta \le n} \right) \ge \lambda,   (13)

and that the width m of the neural network satisfies

m \ge C' \frac{n^3}{\lambda} (\ln m)^C \ln(n/\varepsilon)^2,   (14)

for some large constants C, C' > 0. Then with high probability with respect to the random initialization, the training error decays exponentially,

\sum_{\beta=1}^{n} (f_\beta(t) - y_\beta)^2 \lesssim n e^{-\frac{\lambda t}{2n}},

which reaches \varepsilon at time t \asymp (n/\lambda) \ln(n/\varepsilon).

The assumption that there exists \lambda > 0 such that \lambda_{\min}([K^{(2)}_0(x_\alpha, x_\beta)]_{1 \le \alpha, \beta \le n}) \ge \lambda appears in many previous papers (Du et al., 2018b;a; Arora et al., 2019b). It is proven in (Du et al., 2018b;a) that if no two inputs are parallel, the smallest eigenvalue of the kernel matrix is strictly positive. In general \lambda might depend on the size n of the training data set. For two layer neural networks with random training data, quantitative estimates for \lambda are obtained in (Ghorbani et al., 2019; Zhang et al., 2019; Xie et al., 2016): it is proven that, with high probability with respect to the random training data, \lambda \gtrsim n^\beta for some 0 < \beta < 1/2.

It is proven in (Du et al., 2018a) that if there exists \lambda^{(H)} > 0 such that

\lambda_{\min}\left( [G^{(H)}_0(x_\alpha, x_\beta)]_{1 \le \alpha, \beta \le n} \right) \ge \lambda^{(H)},

then for m \ge C (n/\lambda^{(H)})^4 the gradient descent finds a global minimum. Corollary 2.5 improves this result in two ways: (i) we improve the quartic dependence on n to a cubic dependence; (ii) we recall that K^{(2)}_0 = \sum_{\ell=1}^{H+1} G^{(\ell)}_0, and those kernels G^{(\ell)}_0 are all non-negative definite. The smallest eigenvalue of K^{(2)}_0 is typically much bigger than that of G^{(H)}_0, i.e., \lambda \gg \lambda^{(H)}. Moreover, since K^{(2)}_0 is a sum of H + 1 non-negative definite operators, we expect that \lambda gets larger if the depth H is larger.

The NTH, i.e., (9) and (10), is just an infinite sequence of relationships; by itself it cannot be used to determine the dynamic of the NTK. However, thanks to the a priori estimates (12) of the higher order kernels, for any p \le p^* it holds with high probability that \|K^{(p+1)}_t\|_\infty \lesssim (\ln m)^C/m^{p/2}. The derivative \partial_t K^{(p)}_t is an expression involving the higher order kernel K^{(p+1)}_t, which is small provided that p is large enough. Therefore, we can approximate the original NTH (10) by simply setting \partial_t K^{(p)}_t = 0. In this way, we obtain the following truncated hierarchy of ordinary differential equations of p levels, which we call the truncated NTH:

\partial_t \tilde f_\alpha(t) = -\frac{1}{n} \sum_{\beta=1}^{n} \tilde K^{(2)}_t(x_\alpha, x_\beta)(\tilde f_\beta(t) - y_\beta),
\partial_t \tilde K^{(r)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) = -\frac{1}{n} \sum_{\beta=1}^{n} \tilde K^{(r+1)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)(\tilde f_\beta(t) - y_\beta),
\partial_t \tilde K^{(p)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_p}) = 0,   (15)

where 2 \le r \le p - 1, with initial conditions

\tilde f_\beta(0) = f_\beta(0), \quad \beta = 1, 2, \cdots, n,
\tilde K^{(r)}_0 = K^{(r)}_0, \quad r = 2, 3, \cdots, p.
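Once the initial kernels K_0^{(2)}, ..., K_0^{(p)} and the initial outputs are available (for instance computed from a concrete network at initialization), the truncated hierarchy (15) is a finite closed system of ODEs. The following forward-Euler sketch is illustrative only; the dense-array representation of the kernels and the solver choice are our assumptions, not part of the paper.

```python
# Integrate the truncated NTH (15) by forward Euler.
import numpy as np

def truncated_nth(f0, y, K_init, T, dt):
    """K_init: dict r -> array of shape (n,)*r for r = 2..p; returns f(T) and the kernels."""
    f = f0.copy()
    K = {r: K_init[r].copy() for r in K_init}
    p = max(K)
    n = len(y)
    for _ in range(int(T / dt)):
        res = f - y
        # d/dt f_alpha = -(1/n) sum_beta K^(2)(x_alpha, x_beta) (f_beta - y_beta)
        df = -K[2] @ res / n
        dK = {}
        for r in range(2, p):
            # d/dt K^(r)(...) = -(1/n) sum_beta K^(r+1)(..., x_beta) (f_beta - y_beta):
            # a contraction of the last index of K^(r+1) against the residual vector.
            dK[r] = -np.tensordot(K[r + 1], res, axes=([-1], [0])) / n
        f += dt * df
        for r in dK:
            K[r] += dt * dK[r]          # K^(p) stays frozen, as in (15)
    return f, K
```

In practice a higher-order ODE solver would be preferable; forward Euler simply keeps the sketch short.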

In the following theorem, we show that this system of truncated equations (15) approximates the dynamic of the NTK up to arbitrary precision, provided that p is large enough.

Theorem 2.6. Under Assumptions 2.1 and 2.2, we take an even p \le p^* and further assume that

\lambda_{\min}\left( [K^{(2)}_0(x_\alpha, x_\beta)]_{1 \le \alpha, \beta \le n} \right) \ge \lambda.   (16)

Then there exist constants c, C, C' > 0 such that for

t \le \min\left\{ c\sqrt{\lambda m/n}/(\ln m)^C, \; m^{p^*/(2(p^*+1))}/(\ln m)^{C'} \right\},   (17)

the dynamic (9) can be approximated by the truncated dynamic (15),

\left( \sum_{\beta=1}^{n} (f_\beta(t) - \tilde f_\beta(t))^2 \right)^{1/2} \lesssim \frac{(1+t)\, t^{p-1} \sqrt{n}}{m^{p/2}} \min\left\{ t, \frac{n}{\lambda} \right\},   (18)

and

|K^{(2)}_t(x_\alpha, x_\beta) - \tilde K^{(2)}_t(x_\alpha, x_\beta)| \lesssim \frac{(1+t)\, t^{p-1}}{m^{p/2}} \left( 1 + \frac{(1+t)\, t\, (\ln m)^C}{m} \min\left\{ t, \frac{n}{\lambda} \right\} \right).   (19)

We remark that the error terms, i.e., the right-hand sides of (18) and (19), can be made arbitrarily small, provided that p is large enough. In other words, if we take p large enough, the truncated NTH (15) can approximate the original dynamic (9), (10) up to any precision, provided that the time constraint (17) is satisfied. To better illustrate the scaling in Theorem 2.6, let us treat \lambda, \varepsilon, n as constants, ignore the \log n, \log m factors, and focus on the trade-off between time t and width m. Then (17) simplifies to t \lesssim m^{p^*/(2(p^*+1))}. If p^* is sufficiently large, the exponent approaches 1/2, and the condition is equivalent to t being much smaller than m^{1/2}. In this regime, (18) simplifies to

\left( \sum_{\beta=1}^{n} (f_\beta(t) - \tilde f_\beta(t))^2 \right)^{1/2} \lesssim \left( \frac{t}{\sqrt{m}} \right)^p,   (20)

and similarly (19) simplifies to

|K^{(2)}_t(x_\alpha, x_\beta) - \tilde K^{(2)}_t(x_\alpha, x_\beta)| \lesssim \left( \frac{t}{\sqrt{m}} \right)^p.   (21)

The error terms in (20) and (21) are smaller than any \delta > 0, provided we take p \ge \log(1/\delta)/\log(\sqrt{m}/t).
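As a worked example of this last bound, the smallest admissible truncation level for a target accuracy \delta can be computed directly (the numerical values below are arbitrary illustrations):

```python
# Smallest integer p with (t/sqrt(m))^p <= delta, i.e. p >= log(1/delta)/log(sqrt(m)/t).
import math

def min_truncation_level(m, t, delta):
    assert t < math.sqrt(m), "the bound applies only in the regime t << sqrt(m)"
    return max(2, math.ceil(math.log(1 / delta) / math.log(math.sqrt(m) / t)))

print(min_truncation_level(m=1e6, t=10.0, delta=1e-9))   # -> 5
```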
Now take t \asymp (n/\lambda)\ln(n/\varepsilon), so that Corollary 2.5 guarantees the convergence of the dynamics, and consider two special cases. (i) If we take p = 2, then the error in (18) is O(n^{7/2}\ln(n/\varepsilon)^3/\lambda^3 m), which is negligible provided that the width m is much bigger than n^{7/2}. We conclude that if m is much bigger than n^{7/2}, the truncated NTH gives a complete description of the original dynamic of the NTK up to equilibrium. The condition that m is much bigger than n^{7/2} is better than the previous best available one, which requires m \gtrsim n^4. (ii) If we take p = 3, then the error in (18) is O(n^{9/2}\ln(n/\varepsilon)^4/\lambda^4 m^{3/2}) when t \asymp (n/\lambda)\ln(n/\varepsilon), which is negligible provided that the width m is much bigger than n^3. We conclude that if m is much bigger than n^3, the truncated NTH gives a complete description of the original dynamic of the NTK up to equilibrium. Finally, we note that the estimates in Theorem 2.6 clearly improve for smaller t.

The previous convergence theory of overparametrized neural networks works only for very wide neural networks, i.e., m \gtrsim n^3. For any width (not necessarily m \gtrsim n^3), Theorem 2.6 guarantees that the truncated NTH approximates the training dynamics of deep neural networks. The effect of the width appears in the approximation time and in the error terms (18) and (19): the wider the neural network is, the longer the truncated dynamic (15) approximates the training dynamic and the smaller the approximation error is. We recall from (7) that the NTK is the sum of H + 1 non-negative definite operators, K^{(2)}_t = \sum_{\ell=1}^{H+1} G^{(\ell)}_t. We expect that \lambda as in (16) gets bigger if the depth H is larger. Therefore, large width and depth make the truncated dynamic (15) a better approximation.

Thanks to Theorem 2.6, the truncated NTH (15) provides a good approximation for the evolution of the NTK. The truncated dynamic can also be used to predict the output on new data points. Recall that the training data are \{(x_\beta, y_\beta)\}_{1 \le \beta \le n} \subset \mathbb{R}^d \times \mathbb{R}, and the goal is to predict the output on a new data point x. To do this, we can first use the truncated dynamic to solve for the approximated outputs \{\tilde f_\beta(t)\}_{1 \le \beta \le n}. Then the prediction on the new test point x \in \mathbb{R}^d can be estimated by sequentially solving for the higher order kernels \tilde K^{(p)}_t(x, \mathcal{X}^{p-1}), \tilde K^{(p-1)}_t(x, \mathcal{X}^{p-2}), \cdots, \tilde K^{(2)}_t(x, \mathcal{X}) and for \tilde f_x(t):
\partial_t \tilde f_x(t) = -\frac{1}{n} \sum_{\beta=1}^{n} \tilde K^{(2)}_t(x, x_\beta)(\tilde f_\beta(t) - y_\beta),
\partial_t \tilde K^{(r)}_t(x, x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_{r-1}}) = -\frac{1}{n} \sum_{\beta=1}^{n} \tilde K^{(r+1)}_t(x, x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_{r-1}}, x_\beta)(\tilde f_\beta(t) - y_\beta),
\partial_t \tilde K^{(p)}_t(x, x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_{p-1}}) = 0,   (22)

where 2 \le r \le p - 1.
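The prediction scheme (22) reuses the training residuals already produced by (15); only the "test slices" of the kernels need to be evolved. A short sketch (again illustrative only, with the initial test-point kernel slices assumed to be given):

```python
# Evolve the test-point slices Kx[r][i1,...,i_{r-1}] ~ K^(r)(x, x_{i1}, ..., x_{i_{r-1}})
# along a precomputed training run, then read off the prediction f~_x(T), as in (22).
import numpy as np

def predict_new_point(fx0, Kx_init, residual_path, dt):
    """residual_path: array of shape (steps, n) holding f~(t_k) - y along the training run."""
    fx = fx0
    Kx = {r: Kx_init[r].copy() for r in Kx_init}   # r -> array of shape (n,)*(r-1)
    p = max(Kx)
    n = residual_path.shape[1]
    for res in residual_path:
        dfx = -Kx[2] @ res / n                      # first equation of (22)
        for r in range(2, p):
            Kx[r] = Kx[r] - dt * np.tensordot(Kx[r + 1], res, axes=([-1], [0])) / n
        fx += dt * dfx                              # Kx[p] stays frozen, as in (22)
    return fx
```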

3. Technique overview

In general, the summands appearing in the kernel K^{(r)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) are products of inner products of vectors obtained in the following way: starting from one of the vectors

\frac{a_t}{\sqrt{m}}, \quad \frac{1}{\sqrt{m}}, \quad \{x^{(1)}_\beta, x^{(2)}_\beta, \cdots, x^{(H)}_\beta\}_{\beta \in \{\alpha_1, \alpha_2, \cdots, \alpha_r\}},

(i) multiply by one of the matrices

\left\{ \frac{W^{(2)}_t}{\sqrt{m}}, \frac{(W^{(2)}_t)^\top}{\sqrt{m}}, \cdots, \frac{W^{(H)}_t}{\sqrt{m}}, \frac{(W^{(H)}_t)^\top}{\sqrt{m}} \right\}, \quad \{\sigma'_1(x_\beta), \sigma'_2(x_\beta), \cdots, \sigma'_H(x_\beta)\}_{\beta \in \{\alpha_1, \alpha_2, \cdots, \alpha_r\}};   (23)

(ii) multiply by one of the matrices

\mathrm{diag}(\cdots), \quad \sigma^{(s)}(x_\beta)\, \underbrace{\mathrm{diag}(\cdots) \cdots \mathrm{diag}(\cdots)}_{s-1 \text{ terms}},   (24)

for s \ge 2, where \mathrm{diag}(\cdots) is the diagonalization of a vector obtained by recursively using (i) and (ii).

To describe the vectors appearing in K^{(r)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) in a formal way, we need to introduce some more notation. We denote by \mathcal{D}_0 the set of expressions of the form

\mathcal{D}_0 := \{ e_s e_{s-1} \cdots e_1 e_0 : 0 \le s \le 4H - 3 \},   (25)

where e_j is chosen from the following sets:

e_0 \in \left\{ a_t, \{\sqrt{m}\, x^{(1)}_\beta, \sqrt{m}\, x^{(2)}_\beta, \cdots, \sqrt{m}\, x^{(H)}_\beta\}_{1 \le \beta \le n} \right\},

and for 1 \le j \le s,

e_j \in \left\{ \frac{W^{(2)}_t}{\sqrt{m}}, \frac{(W^{(2)}_t)^\top}{\sqrt{m}}, \cdots, \frac{W^{(H)}_t}{\sqrt{m}}, \frac{(W^{(H)}_t)^\top}{\sqrt{m}}, \{\sigma'_1(x_\beta), \sigma'_2(x_\beta), \cdots, \sigma'_H(x_\beta)\}_{1 \le \beta \le n} \right\}.

We remark that, from expression (7), each summand in K^{(2)}_t(x_{\alpha_1}, x_{\alpha_2}) is of the form

\frac{\langle v_1(t), v_2(t) \rangle}{m}, \quad \text{or} \quad \frac{\langle v_1(t), v_2(t) \rangle}{m} \frac{\langle v_3(t), v_4(t) \rangle}{m},

where v_1(t), v_2(t), v_3(t), v_4(t) \in \mathcal{D}_0. But the set \mathcal{D}_0 contains more terms than those appearing in K^{(2)}_t(\cdot, \cdot). Given that we have constructed \mathcal{D}_0, \mathcal{D}_1, \cdots, \mathcal{D}_r, we denote by \mathcal{D}_{r+1} the set of expressions of the form

\mathcal{D}_{r+1} := \{ e_s e_{s-1} \cdots e_1 e_0 : 0 \le s \le 4H - 3 \},   (26)

where e_j is chosen from the following sets (notice that we have included 1 in the following set, which does not appear in the definition of \mathcal{D}_0):

e_0 \in \left\{ a_t, 1, \{\sqrt{m}\, x^{(1)}_\beta, \sqrt{m}\, x^{(2)}_\beta, \cdots, \sqrt{m}\, x^{(H)}_\beta\}_{1 \le \beta \le n} \right\},

and for 1 \le j \le s, e_j belongs to one of the sets

\left\{ \frac{W^{(2)}_t}{\sqrt{m}}, \frac{(W^{(2)}_t)^\top}{\sqrt{m}}, \cdots, \frac{W^{(H)}_t}{\sqrt{m}}, \frac{(W^{(H)}_t)^\top}{\sqrt{m}}, \{\sigma'_1(x_\beta), \sigma'_2(x_\beta), \cdots, \sigma'_H(x_\beta)\}_{1 \le \beta \le n} \right\},

\{\mathrm{diag}(d) : d \in \mathcal{D}_0 \cup \mathcal{D}_1 \cup \cdots \cup \mathcal{D}_r\},

\{\sigma^{(u+1)}_\ell(x_\beta)\, \mathrm{diag}(d_1)\mathrm{diag}(d_2) \cdots \mathrm{diag}(d_u) : 1 \le \ell \le H,\ 1 \le \beta \le n,\ 1 \le u \le r,\ d_1, d_2, \cdots, d_u \in \mathcal{D}_0 \cup \mathcal{D}_1 \cup \cdots \cup \mathcal{D}_r\}.

Moreover, the total number of diag operations in an expression e_s e_{s-1} \cdots e_1 e_0 \in \mathcal{D}_{r+1} is exactly r + 1. We remark that if d \in \mathcal{D}_s, then it contains s diag operations; on the other hand, by definition, we view \mathrm{diag}(d) as an element with s + 1 diag operations, because the diag in \mathrm{diag}(d) is counted as one diag operation.

We will show in the supplementary materials that each summand in K^{(r)}_t(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) is of the form

\frac{1}{m^{r/2-1}} \prod_{j=1}^{s} \frac{\langle v_{2j-1}(t), v_{2j}(t) \rangle}{m}, \quad 1 \le s \le r, \quad v_1(t), v_2(t), \cdots, v_{2s}(t) \in \mathcal{D}_0 \cup \mathcal{D}_1 \cup \cdots \cup \mathcal{D}_{r-2}.   (27)

The initial value K^{(r)}_0(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) can be estimated by successively conditioning based on the depth of the neural network. A convenient scheme is given by the tensor program (Yang, 2019), which was developed to characterize the scaling limit of neural network computations. In the supplementary materials, we show that at time t = 0 the vectors v_j(0) in (27) are combinations of projections of independent Gaussian vectors. As a consequence, \langle v_{2j-1}(0), v_{2j}(0) \rangle / m concentrates around a certain constant with high probability, and so does the product \prod_{j=1}^{s} \langle v_{2j-1}(0), v_{2j}(0) \rangle / m. This gives the claim (11).

In the supplementary materials, we consider the quantity

\xi(t) = \max\{ \|v_j(t)\|_\infty : v_j(t) \in \mathcal{D}_0 \cup \mathcal{D}_1 \cup \cdots \cup \mathcal{D}_{p^*-1} \}.

Again using the tensor program, we show that with high probability \|v_j(0)\|_\infty \lesssim (\ln m)^C. This gives the estimate of \xi(t) at t = 0. Next we show that the (p^*+1)-th derivative of \xi(t) can be controlled by \xi(t) itself. This gives a self-consistent differential inequality for \xi(t):

\partial_t^{(p^*+1)} \xi(t) \lesssim \frac{\xi(t)^{2p^*}}{m^{p^*/2}}.   (28)

Combining this with the initial estimate of \xi(t), it follows that for time 0 \le t \le m^{p^*/(2(p^*+1))}/(\ln m)^{C'}, it holds that \xi(t) \lesssim (\ln m)^C; in particular \|v_j(t)\|_\infty \lesssim (\ln m)^C. Then the claim (12) in Theorem 2.3 follows.

Thanks to the a priori estimate (12), we show that along the continuous time gradient descent, the higher order kernels K^{(r)}_t vary slowly. We prove Corollaries 2.4 and 2.5, and Theorem 2.6, in the supplementary materials by a Gronwall type argument.
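To make the higher order kernels concrete, note that differentiating (7) along the flow identifies the first kernel beyond the NTK as K_t^{(3)}(x_1, x_2, x_3) = \langle \nabla_\theta K_t^{(2)}(x_1, x_2), \nabla_\theta f(x_3, \theta_t) \rangle; this is the same gradient-of-kernel construction used for the discrete-time kernels in Section 4 below. The sketch here (illustrative only; the one-hidden-layer tanh network and the widths are our assumptions) computes this quantity by forward-over-reverse automatic differentiation, and can be used to probe the m^{-1/2} decay asserted in (11)-(12).

```python
# K^(3)(x1, x2, x3) as a JVP of the empirical NTK in the direction grad_theta f(x3).
import jax
import jax.numpy as jnp

def f(params, x):
    W, a = params
    return a @ (jnp.tanh(W @ x) / jnp.sqrt(W.shape[0]))

def K2(params, x1, x2):
    g1, g2 = jax.grad(f)(params, x1), jax.grad(f)(params, x2)
    return sum(jnp.vdot(u, v) for u, v in
               zip(jax.tree_util.tree_leaves(g1), jax.tree_util.tree_leaves(g2)))

def K3(params, x1, x2, x3):
    tangent = jax.grad(f)(params, x3)
    _, dK = jax.jvp(lambda p: K2(p, x1, x2), (params,), (tangent,))
    return dK

key = jax.random.PRNGKey(1)
x1, x2, x3 = [v / jnp.linalg.norm(v) for v in jax.random.normal(key, (3, 8))]
for m in [256, 1024, 4096]:
    kW, ka = jax.random.split(jax.random.PRNGKey(m))
    params = (jax.random.normal(kW, (m, 8)), jax.random.normal(ka, (m,)))
    print(m, float(K3(params, x1, x2, x3)))   # expect roughly a 1/sqrt(m) decay
```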

4. Discussion and future directions

In this paper, we study the continuous time gradient descent (gradient flow) of deep fully-connected neural networks. We show that the training dynamic is given by a data dependent infinite hierarchy of ordinary differential equations, i.e., the NTH. We also show that this NTH dynamic can be approximated by a finite truncated dynamic up to any precision. This description makes it possible to directly study the change of the NTK for deep neural networks. Here we list some future directions.

1. We mainly study deep fully-connected neural networks in this paper; we believe the same statements can be proven for convolutional and residual neural networks.

2. We focus on the continuous time gradient descent for simplicity. The approach developed here can be generalized to analyze discrete time gradient descent with small step size. We elaborate the main idea here. The discrete time gradient descent is given by

\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t) = \theta_t - \frac{\eta}{n} \sum_{\beta=1}^{n} \nabla_\theta f_\beta(t)(f_\beta(t) - y_\beta),

where \eta is the learning rate. We write the NTK as K^{(2)}(x_\alpha, x_\beta; \theta_t) to make the dependence on \theta_t explicit. To estimate the NTK K^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_{t+1}) at time t + 1, we use the Taylor expansion

K^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_{t+1}) = K^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_t) + \sum_{r=3}^{s-1} \frac{(-\eta)^r}{n^r} \sum_{1 \le \beta_1, \beta_2, \cdots, \beta_{r-2} \le n} \mathcal{K}^{(r)}(x_{\alpha_1}, x_{\alpha_2}, x_{\beta_1}, \cdots, x_{\beta_{r-2}}; \theta_t) \times (f_{\beta_1}(t) - y_{\beta_1}) \cdots (f_{\beta_{r-2}}(t) - y_{\beta_{r-2}}) + \Omega_s,   (29)

where \Omega_s is the error term of the truncation and the higher order kernels \mathcal{K}^{(r)} are given by

\mathcal{K}^{(r)}(x_{\alpha_1}, x_{\alpha_2}, x_{\beta_1}, \cdots, x_{\beta_{r-2}}; \theta_t) = \nabla^{(r-2)}_\theta K^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_t)\big( \nabla_\theta f_{\beta_1}(t), \cdots, \nabla_\theta f_{\beta_{r-2}}(t) \big).

We note that these kernels \mathcal{K}^{(r)} are in general different from the kernels K^{(r)}_t in the NTH. A similar argument as for (12) can be used to derive a priori estimates of the kernels \mathcal{K}^{(r)}: we expect that \|\mathcal{K}^{(r)}\|_\infty \lesssim (\ln m)^C/m^{r/2-1} with high probability with respect to the random initialization. Therefore \|\Omega_s\|_\infty \le O((\ln m)^C \eta^s / n^2 m^{s/2-1}), which can be arbitrarily small provided that \eta \ll \sqrt{m} and s is large enough. A similar procedure can be applied to the NTH in the discrete time dynamics; we will not go into details here. To conclude this discussion, we believe that, under the assumption \eta \ll \sqrt{m}, our analysis in continuous time can be carried over to the discrete time dynamics.

3. It will be interesting to further analyze the behaviors of the truncated dynamics (15), and to understand why it is better than the kernel regression using the limiting NTK. For example, if we truncate the dynamic at p = 3,

\partial_t \tilde f_\alpha(t) = -\frac{1}{n} \sum_{\beta=1}^{n} \tilde K^{(2)}_t(x_\alpha, x_\beta)(\tilde f_\beta(t) - y_\beta),
\partial_t \tilde K^{(2)}_t(x_{\alpha_1}, x_{\alpha_2}) = -\frac{1}{n} \sum_{\beta=1}^{n} \tilde K^{(3)}_t(x_{\alpha_1}, x_{\alpha_2}, x_\beta)(\tilde f_\beta(t) - y_\beta),
\partial_t \tilde K^{(3)}_t(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3}) = 0.   (30)

The difference between (30) and the kernel regression using the limiting NTK is that in (30) the kernel \tilde K^{(2)}_t changes along time at a rate of O((\ln m)^C/m). We denote the residual vector as \tilde r(t) = (\tilde f_1(t) - y_1, \tilde f_2(t) - y_2, \cdots, \tilde f_n(t) - y_n)^\top. Since \tilde K^{(3)}_t does not depend on time t, we can integrate the second equation in (30) to get \tilde K^{(2)}_t = \tilde K^{(2)}_0 - \frac{1}{n} \tilde K^{(3)}_0\big[ \int_0^t \tilde r(s)\,\mathrm{d}s \big]. We denote \tilde R(t) = \int_0^t \tilde r(s)\,\mathrm{d}s, and plug the previous relation into the first equation of (30) to obtain the following system of ordinary differential equations:

\partial_t \tilde r(t) = -\frac{1}{n} \tilde K^{(2)}_0[\tilde r(t)] + \frac{1}{n^2} \tilde K^{(3)}_0[\tilde R(t), \tilde r(t)],
\partial_t \tilde R(t) = \tilde r(t).

The above system of ordinary differential equations can easily be solved numerically. It will be interesting to understand under what conditions the change of the kernel K^{(2)}_t helps optimization and generalization.
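A short numerical sketch of the reduced system above (illustrative only; the kernels and the initial residual below are synthetic stand-ins rather than kernels computed from a network):

```python
# Forward-Euler integration of d/dt r~ = -(1/n) K2_t r~ with K2_t = K2 - (1/n) K3[R~],
# i.e. the reduced p = 3 system for (r~(t), R~(t)).
import numpy as np

def solve_p3_system(K2, K3, r0, T, dt):
    n = len(r0)
    r, R = r0.copy(), np.zeros_like(r0)
    for _ in range(int(T / dt)):
        # K3[R] is the matrix sum_beta K3[:, :, beta] * R[beta]; applying it to r gives
        # the kernel-drift correction on top of the frozen-NTK decay.
        K2_t = K2 - (1.0 / n) * np.tensordot(K3, R, axes=([2], [0]))
        dr = -(1.0 / n) * K2_t @ r
        r, R = r + dt * dr, R + dt * r
    return r, R

rng = np.random.default_rng(0)
n = 20
A = rng.standard_normal((n, n)) / np.sqrt(n)
K2 = A @ A.T + 0.5 * np.eye(n)                     # toy positive definite kernel matrix
K3 = 1e-2 * rng.standard_normal((n, n, n))
K3 = (K3 + K3.transpose(1, 0, 2)) / 2              # symmetrize in the first two slots
r_final, _ = solve_p3_system(K2, K3, rng.standard_normal(n), T=50.0, dt=0.01)
print(np.linalg.norm(r_final))
```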

4. The optimal condition for the NTH approximation. We have shown that m \gtrsim n^3 guarantees that the NTH approximates the deep neural network up to the time when both of them find global minimizers. Our analysis loses a factor of n by estimating the norm of the kernels K^{(r)} through their entry-wise L^\infty norm. We believe that this loss can be recovered by a more careful analysis, and that one can improve the condition to m \gtrsim n^2. To further improve on this condition, one would have to analyze other cancellation effects. Assuming such an analysis is possible, one might reach the condition m \gtrsim n. It would be interesting to see whether this is the best possible condition for approximating deep neural networks by the NTH, i.e., whether one can show with some example that the deep neural network converges for m \ll n while the approximation by the NTH fails.

Acknowledgements

The work of J.H. is supported by the Institute for Advanced Study. The work of H.-T. Y. is partially supported by NSF Grants DMS-1606305 and DMS-1855509, and a Simons Investigator award. The authors would like to thank the reviewers for their thoughtful comments and efforts towards improving our manuscript.

References

Allen-Zhu, Z. and Li, Y. What can ResNet learn efficiently, going beyond kernels? arXiv preprint arXiv:1905.10337, 2019.

Allen-Zhu, Z., Li, Y., and Liang, Y. Learning and generalization in overparameterized neural networks, going beyond two layers. arXiv preprint arXiv:1811.04918, 2018a.

Allen-Zhu, Z., Li, Y., and Song, Z. A convergence theory for deep learning via over-parameterization. In ICML, arXiv:1811.03962, 2018b.

Araujo, D., Oliveira, R. I., and Yukimura, D. A mean-field limit for certain deep neural networks. arXiv preprint arXiv:1906.00193, 2019.

Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R., and Wang, R. On exact computation with an infinitely wide neural net. arXiv preprint arXiv:1904.11955, 2019a.

Arora, S., Du, S. S., Hu, W., Li, Z., and Wang, R. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. arXiv preprint arXiv:1901.08584, 2019b.

Bhojanapalli, S., Neyshabur, B., and Srebro, N. Global optimality of local search for low rank matrix recovery. In Advances in Neural Information Processing Systems, pp. 3873-3881, 2016.

Cao, Y. and Gu, Q. Generalization bounds of stochastic gradient descent for wide and deep neural networks. In Advances in Neural Information Processing Systems, pp. 10835-10845, 2019.

Chizat, L. and Bach, F. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems, pp. 3036-3046, 2018.

Choromanska, A., Henaff, M., Mathieu, M., Ben Arous, G., and LeCun, Y. The loss surfaces of multilayer networks. In Proceedings of the Eighteenth International Conference on Artificial Intelligence and Statistics, pp. 192-204, 2015.

Collobert, R., Weston, J., Bottou, L., Karlen, M., Kavukcuoglu, K., and Kuksa, P. Natural language processing (almost) from scratch. Journal of Machine Learning Research, 12(Aug):2493-2537, 2011.

Dauphin, Y. N., Pascanu, R., Gulcehre, C., Cho, K., Ganguli, S., and Bengio, Y. Identifying and attacking the saddle point problem in high-dimensional non-convex optimization. In Advances in Neural Information Processing Systems, pp. 2933-2941, 2014.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.

Du, S. S., Lee, J. D., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. In ICML, arXiv:1811.03804, 2018a.

Du, S. S., Zhai, X., Poczos, B., and Singh, A. Gradient descent provably optimizes over-parameterized neural networks. In ICLR, arXiv:1810.02054, 2018b.

Dyer, E. and Gur-Ari, G. Asymptotics of wide networks from Feynman diagrams. arXiv preprint arXiv:1909.11304, 2019.

Ge, R., Huang, F., Jin, C., and Yuan, Y. Escaping from saddle points—online stochastic gradient for tensor decomposition. In Proceedings of The 28th Conference on Learning Theory, pp. 797-842, 2015.

Ge, R., Lee, J. D., and Ma, T. Matrix completion has no spurious local minimum. In Advances in Neural Information Processing Systems, pp. 2973-2981, 2016.

Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. Linearized two-layers neural networks in high dimension. arXiv preprint arXiv:1904.12191, 2019.

Glorot, X. and Bengio, Y. Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 249-256, 2010.

Hinton, G., Deng, L., Yu, D., Dahl, G., Mohamed, A.-r., Jaitly, N., Senior, A., Vanhoucke, V., Nguyen, P., Kingsbury, B., et al. Deep neural networks for acoustic modeling in speech recognition. IEEE Signal Processing Magazine, 29, 2012.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571-8580, 2018.

Jin, C., Ge, R., Netrapalli, P., Kakade, S. M., and Jordan, M. I. How to escape saddle points efficiently. In Proceedings of the 34th International Conference on Machine Learning, pp. 1724-1732. JMLR.org, 2017.

Kawaguchi, K. Deep learning without poor local minima. In Advances in Neural Information Processing Systems, pp. 586-594, 2016.

Kawaguchi, K. and Huang, J. Gradient descent finds global minima for generalizable deep neural networks of practical sizes. arXiv preprint arXiv:1908.02419, 2019.

Kawaguchi, K. and Kaelbling, L. P. Elimination of all bad local minima in deep learning. arXiv preprint arXiv:1901.00279, 2019.

Krizhevsky, A., Sutskever, I., and Hinton, G. E. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pp. 1097-1105, 2012.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, 1998.

Lee, J., Xiao, L., Schoenholz, S. S., Bahri, Y., Sohl-Dickstein, J., and Pennington, J. Wide neural networks of any depth evolve as linear models under gradient descent. arXiv preprint arXiv:1902.06720, 2019.

Lee, J. D., Simchowitz, M., Jordan, M. I., and Recht, B. Gradient descent only converges to minimizers. In Conference on Learning Theory, pp. 1246-1257, 2016.

Li, Y. and Liang, Y. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems, pp. 8157-8166, 2018.

Liang, S., Sun, R., Lee, J. D., and Srikant, R. Adding one neuron can eliminate all bad local minima. In Advances in Neural Information Processing Systems, 2018.

Mei, S., Misiakiewicz, T., and Montanari, A. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. arXiv preprint arXiv:1902.06015, 2019.

Nguyen, P.-M. Mean field limit of the learning dynamics of multilayer neural networks. arXiv preprint arXiv:1902.02880, 2019.

Park, D., Kyrillidis, A., Caramanis, C., and Sanghavi, S. Non-square matrix sensing without spurious local minima via the Burer-Monteiro approach. arXiv preprint arXiv:1609.03240, 2016.

Sainath, T. N., Mohamed, A.-r., Kingsbury, B., and Ramabhadran, B. Deep convolutional neural networks for LVCSR. In 2013 IEEE International Conference on Acoustics, Speech and Signal Processing, pp. 8614-8618. IEEE, 2013.

Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., et al. Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587):484-489, 2016.

Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., Hubert, T., Baker, L., Lai, M., Bolton, A., et al. Mastering the game of Go without human knowledge. Nature, 550(7676):354, 2017.

Sirignano, J. and Spiliopoulos, K. Mean field analysis of deep neural networks. arXiv preprint arXiv:1903.04440, 2019.

Song, M., Montanari, A., and Nguyen, P. A mean field view of the landscape of two-layers neural networks. Proceedings of the National Academy of Sciences, 115:E7665-E7671, 2018.

Song, Z. and Yang, X. Quadratic suffices for over-parametrization via matrix Chernoff bound. arXiv preprint arXiv:1906.03593, 2019.

Sun, J., Qu, Q., and Wright, J. Complete dictionary recovery over the sphere I: Overview and the geometric picture. IEEE Transactions on Information Theory, 63(2):853-884, 2016.

Sun, J., Qu, Q., and Wright, J. A geometric analysis of phase retrieval. Foundations of Computational Mathematics, 18(5):1131-1198, 2018.

Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., and Rabinovich, A. Going deeper with convolutions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1-9, 2015.

Wu, Y., Schuster, M., Chen, Z., Le, Q. V., Norouzi, M., Macherey, W., Krikun, M., Cao, Y., Gao, Q., Macherey, K., et al. Google's neural machine translation system: Bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144, 2016.

Xie, B., Liang, Y., and Song, L. Diverse neural network learns true target functions. arXiv preprint arXiv:1611.03131, 2016.

Yang, G. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. CoRR, abs/1902.04760, 2019. URL http://arxiv.org/abs/1902.04760.

Yehudai, G. and Shamir, O. On the power and limitations of random features for understanding neural networks. arXiv preprint arXiv:1904.00687, 2019.

Zhang, G., Martens, J., and Grosse, R. B. Fast convergence of natural gradient descent for over-parameterized neural networks. In Advances in Neural Information Processing Systems, pp. 8080-8091, 2019.

Zou, D. and Gu, Q. An improved analysis of training over-parameterized deep neural networks. In Advances in Neural Information Processing Systems, pp. 2053-2062, 2019.

Zou, D., Cao, Y., Zhou, D., and Gu, Q. Stochastic gradient descent optimizes over-parameterized deep ReLU networks. arXiv preprint arXiv:1811.08888, 2018.