arXiv:1909.08156v1 [cs.LG] 18 Sep 2019

Dynamics of Deep Neural Networks and Neural Tangent Hierarchy

Jiaoyang Huang
IAS
E-mail: [email protected]

Horng-Tzer Yau
Harvard University
E-mail: [email protected]

Abstract

The evolution of a deep neural network trained by the gradient descent can be described by its neural tangent kernel (NTK) as introduced in [20], where it was proven that in the infinite width limit the NTK converges to an explicit limiting kernel and it stays constant during training. The NTK was also implicit in some other recent papers [6, 13, 14]. In the overparametrization regime, a fully-trained deep neural network is indeed equivalent to the kernel regression predictor using the limiting NTK. And the gradient descent achieves zero training loss for a deep overparameterized neural network. However, it was observed in [5] that there is a performance gap between the kernel regression using the limiting NTK and the deep neural networks. This performance gap is likely to originate from the change of the NTK along training due to the finite width effect. The change of the NTK along training is central to describe the generalization features of deep neural networks.

In the current paper, we study the dynamic of the NTK for finite width deep fully-connected neural networks. We derive an infinite hierarchy of ordinary differential equations, the neural tangent hierarchy (NTH), which captures the gradient descent dynamic of the deep neural network. Moreover, under certain conditions on the neural network width and the data set dimension, we prove that the truncated hierarchy of NTH approximates the dynamic of the NTK up to arbitrary precision. This description makes it possible to directly study the change of the NTK for deep neural networks, and sheds light on the observation that deep neural networks outperform kernel regressions using the corresponding limiting NTK.

The work of H.-T. Y. is partially supported by NSF Grants DMS-1606305 and DMS-1855509, and a Simons Investigator award.

1 Introduction

Deep neural networks have become popular due to their unprecedented success in a variety of machine learning tasks. Image recognition [25, 26, 42], speech recognition [19, 34], playing Go [35, 36] and natural language understanding [10, 12, 44] are just a few of the recent achievements. However, one aspect of deep neural networks that is not well understood is training. Training a deep neural network is usually done via a gradient descent based algorithm. Analyzing such training dynamics is challenging. Firstly, as highly nonlinear structures, deep neural networks usually involve a large number of parameters. Secondly, since training is a highly non-convex optimization problem, there is no guarantee that a gradient based algorithm will be able to find the optimal parameters efficiently. One question then arises: given such complexities, is it possible to obtain a succinct description of the training dynamics?

In this paper, we focus on the empirical risk minimization problem with the quadratic loss

$$\min_{\theta} L(\theta) = \frac{1}{2n}\sum_{\alpha=1}^{n} \big(f(x_\alpha,\theta) - y_\alpha\big)^2,$$

where $\{x_\alpha\}_{\alpha=1}^n$ are the training inputs, $\{y_\alpha\}_{\alpha=1}^n$ are the labels, and the dependence is modeled by a deep fully-connected feedforward neural network with $H$ hidden layers. The network has $d$ input nodes, and the input vector is given by $x \in \mathbb{R}^d$. For $1 \le \ell \le H$, the $\ell$-th hidden layer has $m$ neurons. Let $x^{(\ell)}$ be the output of the $\ell$-th layer with $x^{(0)} = x$. Then the feedforward neural network is given by the set of recursive equations:

$$x^{(\ell)} = \frac{1}{\sqrt{m}}\,\sigma\big(W^{(\ell)} x^{(\ell-1)}\big), \quad \ell = 1, 2, \cdots, H, \tag{1.1}$$

where $W^{(\ell)} \in \mathbb{R}^{m\times d}$ if $\ell = 1$ and $W^{(\ell)} \in \mathbb{R}^{m\times m}$ if $2 \le \ell \le H$ are the weight matrices, and $\sigma$ is the activation unit, which is applied coordinate-wise to its input. The output of the neural network is

$$f(x,\theta) = a^\top x^{(H)} \in \mathbb{R}, \tag{1.2}$$

where $a \in \mathbb{R}^m$ is the weight vector for the output layer. We denote the vector containing all trainable parameters by $\theta = (\mathrm{vec}(W^{(1)}), \mathrm{vec}(W^{(2)}), \ldots, \mathrm{vec}(W^{(H)}), a)$. We remark that this parametrization is nonstandard because of the $1/\sqrt{m}$ factors. However, it has already been adopted in several recent works [13, 14, 20, 27]. We note that the predictions and training dynamics of (1.1) are identical to those of standard networks, up to a scaling factor $1/\sqrt{m}$ in the learning rate for each parameter.
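The recursion (1.1) and the output (1.2) can be sketched in a few lines of NumPy. This is only an illustration: the helper names `init_params` and `forward`, the choice of tanh as activation, the unit-variance Gaussian weights and the unit-norm input are all assumptions made for the example.

```python
import numpy as np

def init_params(d, m, H, rng, sigma_w=1.0, sigma_a=1.0):
    """Gaussian weights: W^(1) is m x d, W^(l) is m x m for l >= 2, a is in R^m."""
    Ws = [sigma_w * rng.standard_normal((m, d))]
    Ws += [sigma_w * rng.standard_normal((m, m)) for _ in range(H - 1)]
    a = sigma_a * rng.standard_normal(m)
    return Ws, a

def forward(x, Ws, a, act=np.tanh):
    """Return f(x, theta) and the layer outputs x^(0), ..., x^(H) of (1.1)-(1.2)."""
    m = a.shape[0]
    xs = [x]
    for W in Ws:
        # x^(l) = sigma(W^(l) x^(l-1)) / sqrt(m)
        xs.append(act(W @ xs[-1]) / np.sqrt(m))
    return a @ xs[-1], xs

rng = np.random.default_rng(0)
d, m, H = 5, 512, 3
Ws, a = init_params(d, m, H, rng)
x = rng.standard_normal(d)
x /= np.linalg.norm(x)                      # unit-norm input (assumption of this example)
f, xs = forward(x, Ws, a)
layer_norms = [float(np.linalg.norm(v)) for v in xs]
```

With a bounded activation such as tanh, the $1/\sqrt{m}$ factor keeps every `layer_norms` entry at most one, whatever the width.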

We initialize the neural network with random Gaussian weights following the Xavier initialization scheme [18]. More precisely, we set the initial parameter vector $\theta_0$ as $W^{(\ell)}_{ij} \sim \mathcal{N}(0, \sigma_w^2)$, $a_i \sim \mathcal{N}(0, \sigma_a^2)$. In this way, for the randomly initialized neural network, we have that the $L_2$ norms of the output of each layer are of order one, i.e. $\|x^{(\ell)}\|_2^2 = O(1)$ for $0 \le \ell \le H$, and $f(x, \theta_0) = O(1)$ with high probability. In this paper, we train all layers of the neural network with continuous time gradient descent (gradient flow): for any time $t \ge 0$

$$\partial_t W_t^{(\ell)} = -\partial_{W^{(\ell)}} L(\theta_t), \quad \ell = 1, 2, \cdots, H, \qquad \partial_t a_t = -\partial_a L(\theta_t), \tag{1.3}$$

where $\theta_t = (\mathrm{vec}(W_t^{(1)}), \mathrm{vec}(W_t^{(2)}), \ldots, \mathrm{vec}(W_t^{(H)}), a_t)$.

For simplicity of notation, we write $\sigma(W^{(\ell)} x^{(\ell-1)})$ as $\sigma_\ell(x)$, or simply $\sigma_\ell$ if the context is clear. We write its derivative $\mathrm{diag}(\sigma'(W^{(\ell)} x^{(\ell-1)}))$ as $\sigma'_\ell(x) = \sigma^{(1)}_\ell(x)$, and the $r$-th derivative $\mathrm{diag}(\sigma^{(r)}(W^{(\ell)} x^{(\ell-1)}))$ as $\sigma^{(r)}_\ell(x)$, or $\sigma^{(r)}_\ell$, for $r \ge 1$. In this notation, $\sigma^{(r)}_\ell(x)$ are diagonal matrices. With this notation, the continuous time gradient descent dynamic (1.3) reads explicitly

$$\partial_t W_t^{(\ell)} = -\partial_{W^{(\ell)}} L(\theta_t) = -\frac{1}{n}\sum_{\beta=1}^{n}\left(\sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\otimes (x_\beta^{(\ell-1)})^\top \big(f(x_\beta,\theta_t) - y_\beta\big), \tag{1.4}$$

for $\ell = 1, 2, \cdots, H$, and

$$\partial_t a_t = -\partial_a L(\theta_t) = -\frac{1}{n}\sum_{\beta=1}^{n} x_\beta^{(H)}\big(f(x_\beta,\theta_t) - y_\beta\big). \tag{1.5}$$
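The vector multiplying $(x_\beta^{(\ell-1)})^\top$ in (1.4) can be built by a backward recursion over the layers. The following sketch computes the per-sample gradients of $f$ this way and compares one entry per layer against a central finite difference. It assumes a tanh activation and small illustrative sizes; the function name `grads_f` is hypothetical.

```python
import numpy as np

def grads_f(x, Ws, a):
    """Gradients of f(x, theta) w.r.t. each W^(l) and a, following (1.4)-(1.5)."""
    m, H = a.shape[0], len(Ws)
    xs, pre = [x], []
    for W in Ws:
        pre.append(W @ xs[-1])                      # pre-activation W^(l) x^(l-1)
        xs.append(np.tanh(pre[-1]) / np.sqrt(m))
    dsig = [1.0 - np.tanh(u) ** 2 for u in pre]     # diagonals of sigma'_l(x)
    v = dsig[-1] * a / np.sqrt(m)                   # sigma'_H a / sqrt(m)
    gWs = [None] * H
    for l in range(H - 1, -1, -1):
        gWs[l] = np.outer(v, xs[l])                 # v (x) (x^(l-1))^T, with Ws[l] = W^(l+1)
        if l > 0:
            # prepend sigma'_l (W^(l+1))^T / sqrt(m) to the chain
            v = dsig[l - 1] * (Ws[l].T @ v) / np.sqrt(m)
    return a @ xs[-1], gWs, xs[-1]                  # f, df/dW^(l), df/da = x^(H)

rng = np.random.default_rng(1)
d, m, H = 4, 16, 3
Ws = [rng.standard_normal((m, d))] + [rng.standard_normal((m, m)) for _ in range(H - 1)]
a = rng.standard_normal(m)
x = rng.standard_normal(d)

f0, gWs, ga = grads_f(x, Ws, a)

# Central finite difference on one entry of each weight matrix.
eps, i, j = 1e-6, 2, 1
max_err = 0.0
for l in range(H):
    Wp = [W.copy() for W in Ws]
    Wm = [W.copy() for W in Ws]
    Wp[l][i, j] += eps
    Wm[l][i, j] -= eps
    fd = (grads_f(x, Wp, a)[0] - grads_f(x, Wm, a)[0]) / (2 * eps)
    max_err = max(max_err, abs(fd - gWs[l][i, j]))
```

The loss gradient in (1.4) is then the average over $\beta$ of these per-sample gradients weighted by the residuals $f(x_\beta,\theta_t) - y_\beta$.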

1.1 Neural Tangent Kernel

A recent paper [20] introduced the Neural Tangent Kernel (NTK) and proved the limiting NTK captures the behavior of fully-connected deep neural networks in the infinite width limit trained by gradient descent:

$$\begin{aligned}
\partial_t f(x,\theta_t) &= \partial_\theta f(x,\theta_t)\,\partial_t\theta_t = -\partial_\theta f(x,\theta_t)\,\partial_\theta L(\theta_t) \\
&= -\frac{1}{n}\sum_{\beta=1}^{n}\partial_\theta f(x,\theta_t)\,\partial_\theta f(x_\beta,\theta_t)\big(f(x_\beta,\theta_t) - y_\beta\big) = -\frac{1}{n}\sum_{\beta=1}^{n}K_t^{(2)}(x,x_\beta)\big(f(x_\beta,\theta_t) - y_\beta\big),
\end{aligned} \tag{1.6}$$

where the NTK $K_t^{(2)}(\cdot,\cdot)$ is given by

$$K_t^{(2)}(x_\alpha,x_\beta) = \langle \partial_\theta f(x_\alpha,\theta_t), \partial_\theta f(x_\beta,\theta_t)\rangle = \sum_{\ell=1}^{H+1} G_t^{(\ell)}(x_\alpha,x_\beta), \tag{1.7}$$

and for $1 \le \ell \le H$,

$$\begin{aligned}
G_t^{(\ell)}(x_\alpha,x_\beta) &= \langle \partial_{W^{(\ell)}} f(x_\alpha,\theta_t), \partial_{W^{(\ell)}} f(x_\beta,\theta_t)\rangle \\
&= \left\langle \sigma'_\ell(x_\alpha)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\alpha)\frac{a_t}{\sqrt{m}},\ \sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right\rangle \langle x_\alpha^{(\ell-1)}, x_\beta^{(\ell-1)}\rangle,
\end{aligned}$$

and

$$G_t^{(H+1)}(x_\alpha,x_\beta) = \langle \partial_a f(x_\alpha,\theta_t), \partial_a f(x_\beta,\theta_t)\rangle = \langle x_\alpha^{(H)}, x_\beta^{(H)}\rangle.$$

The NTK $K_t^{(2)}(\cdot,\cdot)$ varies along training. However, in the infinite width limit, the training dynamic is very simple: the NTK does not change along training, $K_t^{(2)}(\cdot,\cdot) = K_\infty^{(2)}(\cdot,\cdot)$. The network function $f(x,\theta_t)$ follows a linear differential equation [20]:

$$\partial_t f(x,\theta_t) = -\frac{1}{n}\sum_{\beta=1}^{n} K_\infty^{(2)}(x,x_\beta)\big(f(x_\beta,\theta_t) - y_\beta\big), \tag{1.8}$$

which becomes analytically tractable. In other words, the training dynamic is equivalent to the kernel regression using the limiting NTK $K_\infty^{(2)}(\cdot,\cdot)$. While the linearization (1.8) is only exact in the infinite width limit, for a sufficiently wide deep neural network, (1.8) still provides a good approximation of the learning dynamic for the corresponding deep neural network [13, 14, 27]. As a consequence, it was proven in [13, 14] that, for a fully-connected wide neural network with $m \gtrsim n^4$ under certain assumptions on the data set, the gradient descent converges to zero training loss at a linear rate. Although highly overparametrized neural networks are equivalent to the kernel regression, it is possible to show that the class of finite width neural networks is more expressive than the limiting NTK. It has been shown in [1, 17, 46] that there are simple functions that can be efficiently learnt by finite width neural networks, but not by the kernel regression using the limiting NTK.
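Restricted to the training inputs, (1.8) is a linear system for the residual $r(t) = (f(x_\beta,\theta_t) - y_\beta)_{\beta}$, namely $\partial_t r = -\frac{1}{n}K r$ with a fixed Gram matrix $K$, so $r(t) = e^{-tK/n} r(0)$. The sketch below evaluates this closed form through an eigendecomposition and cross-checks it against forward Euler. The matrix `K` is a random positive semi-definite stand-in rather than an actual limiting NTK, and the function name is hypothetical.

```python
import numpy as np

def kernel_flow_residual(K, r0, t):
    """Solve d/dt r = -(1/n) K r for a fixed symmetric PSD matrix K (r = f - y on the training set)."""
    n = K.shape[0]
    lam, U = np.linalg.eigh(K)
    return U @ (np.exp(-lam * t / n) * (U.T @ r0))

rng = np.random.default_rng(2)
n = 6
A = rng.standard_normal((n, 2 * n))
K = A @ A.T / (2 * n)            # hypothetical stand-in for the limiting NTK Gram matrix
r0 = rng.standard_normal(n)      # initial residual f(0) - y

# Cross-check the closed form against forward-Euler integration of the same ODE.
t_end, steps = 3.0, 10000
dt = t_end / steps
r = r0.copy()
for _ in range(steps):
    r = r - dt * (K @ r) / n
r_exact = kernel_flow_residual(K, r0, t_end)
euler_gap = float(np.max(np.abs(r - r_exact)))
```

Each eigen-direction of $K$ decays at its own rate $\lambda_i/n$, which is why the smallest eigenvalue controls the convergence rate in the results below.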

1.2 Contribution

There is a performance gap between the kernel regression (1.8) using the limiting NTK and the deep neural networks. It was observed in [5] that convolutional neural networks outperform their corresponding limiting NTK by 5% - 6%. This performance gap is likely to originate from the change of the NTK along training due to the finite width effect. The change of the NTK along training has its benefits on generalization.

In the current paper, we study the dynamic of the NTK for finite width deep fully-connected neural networks. Here we summarize our main contributions:

• We show the gradient descent dynamic is captured by an infinite hierarchy of ordinary differential equations, the neural tangent hierarchy (NTH). Different from the limiting NTK (1.7), which depends only on the neural network architecture, the NTH is data dependent and capable of learning data-dependent features.

• We derive a priori estimates of the higher order kernels involved in the NTH. Using these a priori estimates as input, we confirm a numerical observation in [27] that the NTK varies at a rate of order $O(1/m)$. As a corollary, this implies that for a fully-connected wide neural network with $m \gtrsim n^3$, the gradient descent converges to zero training loss at a linear rate, which improves the results in [13].

• The NTH is just an infinite sequence of relationships. Without truncation, it cannot be used to determine the dynamic of the NTK. Using the a priori estimates of the higher order kernels as input, we construct a truncated hierarchy of ordinary differential equations, the truncated NTH. We show that this system of truncated equations approximates the dynamic of the NTK up to a certain time with arbitrary precision. This description makes it possible to directly study the change of the NTK for deep neural networks.

1.3 Notations

In the paper, we fix a large constant $p > 0$, which appears in Assumptions 2.1 and 2.2. We use $c$, $C$ to represent universal constants, which might be different from line to line. We write $a = O(b)$ or $a \lesssim b$ if there exists some large universal constant $C$ such that $|a| \le Cb$. We write $a \gtrsim b$ if there exists some small universal constant $c > 0$ such that $a \ge cb$. We write $a \asymp b$ if there exist universal constants $c, C$ such that $cb \le |a| \le Cb$. We reserve $n$ for the number of input samples and $m$ for the width of the neural network. For practical neural networks, we always have that $m \lesssim \mathrm{poly}(n)$ and $n \lesssim \mathrm{poly}(m)$. We denote the set of input samples as $\mathcal{X} = \{x_1, x_2, \cdots, x_n\}$. For simplicity of notation, we write the output of the neural network as $f_\beta(t) = f(x_\beta,\theta_t)$. We denote the vector $L_2$ norm as $\|\cdot\|_2$, the vector or function $L_\infty$ norm as $\|\cdot\|_\infty$, the matrix spectral norm as $\|\cdot\|_{2\to 2}$, and the matrix Frobenius norm as $\|\cdot\|_F$. We say that an event holds with high probability if it holds with probability at least $1 - e^{-m^c}$ for some $c > 0$. Then the intersection of $\mathrm{poly}(n,m)$ many high probability events is still a high probability event, provided $m$ is large enough. In the paper, we treat $c_r$, $C_r$ in Assumptions 2.1 and 2.2, and the depth $H$, as constants. We will not keep track of them.

1.4 Related Work

In this section, we survey an incomplete list of previous works on the optimization aspects of deep neural networks.

Because of the highly non-convex nature of deep neural networks, gradient based algorithms can potentially get stuck near a critical point, i.e., a saddle point or local minimum. So one important question in deep neural networks is: what does the loss landscape look like? One promising candidate for loss landscapes is the class of functions that satisfy: (i) all local minima are global minima and (ii) there exists a negative curvature for every saddle point. A line of recent results shows that, in many optimization problems of interest [7, 15, 16, 33, 40, 41], loss landscapes are in this class. For this function class, (perturbed) gradient descent [15, 21, 28] can find a global minimum. However, even for a three-layer linear network, there exists a saddle point that does not have a negative curvature [22]. So it is unclear whether this geometry-based approach can be used to obtain the global convergence guarantee of first-order methods. Another approach is to show that practical deep neural networks allow some additional structure or assumption that makes non-convex optimization tractable. Under certain simplifying assumptions, it has been proven recently that there are novel loss landscape structures in deep neural networks, which may play a role in making the optimization tractable [9, 11, 22, 24, 30].

Recently, it was proved in a series of papers that, if the size of a neural network is significantly larger than the size of the dataset, the (stochastic) gradient descent algorithm can find optimal parameters [3, 13, 14, 29, 39, 47]. In the overparametrization regime, a fully-trained deep neural network is indeed equivalent to the kernel regression predictor using the limiting NTK (1.8). As a consequence, the gradient descent achieves zero training loss for a deep overparameterized neural network. Under further assumptions, it can be shown that the trained networks generalize [2, 6]. Unfortunately, there is a significant gap between the overparametrized neural networks, which are provably trainable, and neural networks in common practice. Typically, deep neural networks used in practical applications are trainable, and yet, much smaller than what the previous theories require to ensure trainability. In [23], it is proven that gradient descent can find a global minimum for certain deep neural networks of sizes commonly encountered in practice.

Training dynamics of neural networks in the mean field setting have been studied in [4, 8, 31, 32, 37, 38]. Their mean field analysis describes distributional dynamics of neural network parameters via certain nonlinear partial differential equations, in the asymptotic regime of large network sizes and a large number of stochastic gradient descent training iterations. However, their analysis is restricted to neural networks in the mean-field framework with a normalization factor $1/m$, different from our normalization $1/\sqrt{m}$, which is commonly used in modern networks [18].

2 Main results

Assumption 2.1. The activation function $\sigma$ is smooth, and for any $1 \le r \le 2p+1$, there exists a constant $C_r > 0$ such that the $r$-th derivative of $\sigma$ satisfies $\|\sigma^{(r)}(x)\|_\infty \le C_r$.

Assumption 2.1 is satisfied by common activation units such as the sigmoid and the hyperbolic tangent. Moreover, the softplus activation, which is defined as $\sigma_a(x) = \ln(1 + \exp(ax))/a$, satisfies Assumption 2.1 for any hyperparameter $a \in \mathbb{R}_{>0}$. The softplus activation can approximate the ReLU activation to any desired accuracy as

$$\sigma_a(x) \to \mathrm{relu}(x) \quad \text{as } a \to \infty,$$

where relu represents the ReLU activation.
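The convergence of softplus to ReLU is easy to quantify: $\sigma_a(x) - \mathrm{relu}(x)$ is largest at $x = 0$, where it equals $\ln 2 / a$. A small numerical check, with an illustrative grid and illustrative values of $a$:

```python
import numpy as np

def softplus(x, a):
    """sigma_a(x) = ln(1 + exp(a x)) / a, evaluated stably via logaddexp."""
    return np.logaddexp(0.0, a * x) / a

def relu(x):
    return np.maximum(x, 0.0)

grid = np.linspace(-5.0, 5.0, 2001)      # grid containing (a point at or next to) x = 0
gaps = {a: float(np.max(np.abs(softplus(grid, a) - relu(grid)))) for a in (1.0, 10.0, 100.0)}
# Each gap is approximately ln(2) / a, attained at x = 0.
```

The uniform error therefore decays like $1/a$, while the bounds $C_r$ on the higher derivatives in Assumption 2.1 grow with $a$.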

Assumption 2.2. There exists a small constant $c > 0$ such that the training inputs satisfy $c < \|x_\alpha\|_2 \le c^{-1}$. For any $1 \le r \le 2p+1$, there exists a constant $c_r > 0$ such that for any distinct indices $1 \le \alpha_1, \alpha_2, \cdots, \alpha_r \le n$, the smallest singular value of the data matrix $[x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}]$ is at least $c_r$.

For more general input data, we can always normalize them such that $c < \|x_\alpha\|_2 \le c^{-1}$. Under this normalization, for the randomly initialized deep neural network, it holds that $\|x_\alpha^{(\ell)}\|_2 = O(1)$ for all $1 \le \ell \le H$, where the implicit constants depend on $\ell$. The second part of Assumption 2.2 requires that any small number of input data points $x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}$ are linearly independent.

Theorem 2.3. Under Assumptions 2.1 and 2.2, there exists an infinite family of operators $K_t^{(r)} : \mathcal{X}^r \mapsto \mathbb{R}$ for $r \ge 2$, such that the continuous time gradient descent dynamic is given by an infinite hierarchy of ordinary differential equations, i.e., the NTH,

$$\partial_t\big(f_\alpha(t) - y_\alpha\big) = -\frac{1}{n}\sum_{\beta=1}^{n} K_t^{(2)}(x_\alpha,x_\beta)\big(f_\beta(t) - y_\beta\big), \tag{2.1}$$

and for any $r \ge 2$,

$$\partial_t K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) = -\frac{1}{n}\sum_{\beta=1}^{n} K_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)\big(f_\beta(t) - y_\beta\big). \tag{2.2}$$

There exists a deterministic family (independent of $m$) of operators $K^{(r)} : \mathcal{X}^r \mapsto \mathbb{R}$ for $2 \le r \le p+1$, with $K^{(r)} = 0$ if $r$ is odd, such that with high probability with respect to the random initialization, there exist some constants $C, C' > 0$ such that

$$\left\| K_0^{(r)} - \frac{K^{(r)}}{m^{r/2-1}} \right\|_\infty \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}, \tag{2.3}$$

and for $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$,

$$\| K_t^{(r)} \|_\infty \lesssim \frac{(\ln m)^C}{m^{r/2-1}}. \tag{2.4}$$
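The first step of the hierarchy can be checked numerically in the simplest case $H = 1$, where $f(x,\theta) = a^\top\sigma(Wx)/\sqrt{m}$. By the chain rule, (2.2) with $r = 2$ holds with $K_t^{(3)}(x_{\alpha_1},x_{\alpha_2},x_\beta) = \langle \partial_\theta K_t^{(2)}(x_{\alpha_1},x_{\alpha_2}), \partial_\theta f(x_\beta,\theta_t)\rangle$. The sketch below writes this out by hand for a tanh activation and compares the right-hand side of (2.2) with a finite difference of $K^{(2)}$ along the gradient flow direction. All names and sizes are illustrative assumptions, not the construction used in the proofs.

```python
import numpy as np

def sig(u):   return np.tanh(u)
def dsig(u):  return 1.0 - np.tanh(u) ** 2
def d2sig(u): return -2.0 * np.tanh(u) * (1.0 - np.tanh(u) ** 2)

def f_and_grads(W, a, x):
    """f(x), df/dW and df/da for the one-hidden-layer network f = a^T sig(W x) / sqrt(m)."""
    m = a.shape[0]
    u = W @ x
    return a @ sig(u) / np.sqrt(m), np.outer(a * dsig(u), x) / np.sqrt(m), sig(u) / np.sqrt(m)

def K2(W, a, x1, x2):
    """NTK (1.7) for H = 1: G^(2) + G^(1)."""
    m = a.shape[0]
    u1, u2 = W @ x1, W @ x2
    return (sig(u1) @ sig(u2) + (x1 @ x2) * np.sum(a**2 * dsig(u1) * dsig(u2))) / m

def K3(W, a, x1, x2, x3):
    """K3(x1, x2, x3) = <grad_theta K2(x1, x2), grad_theta f(x3)>, written out for H = 1."""
    m = a.shape[0]
    u1, u2 = W @ x1, W @ x2
    s = x1 @ x2
    gKa = 2.0 * s * a * dsig(u1) * dsig(u2) / m                       # dK2/da
    c1 = (dsig(u1) * sig(u2) + s * a**2 * d2sig(u1) * dsig(u2)) / m   # coefficient of x1 in dK2/dW
    c2 = (sig(u1) * dsig(u2) + s * a**2 * dsig(u1) * d2sig(u2)) / m   # coefficient of x2 in dK2/dW
    gKW = np.outer(c1, x1) + np.outer(c2, x2)
    _, gfW, gfa = f_and_grads(W, a, x3)
    return gKa @ gfa + np.sum(gKW * gfW)

rng = np.random.default_rng(3)
n, d, m = 4, 3, 32
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)
y = rng.standard_normal(n)
W, a = rng.standard_normal((m, d)), rng.standard_normal(m)

# Gradient of the quadratic loss L = (1/2n) sum_b (f_b - y_b)^2.
gW, ga, res = np.zeros_like(W), np.zeros_like(a), np.zeros(n)
for b in range(n):
    fb, gfW, gfa = f_and_grads(W, a, X[b])
    res[b] = fb - y[b]
    gW += gfW * res[b] / n
    ga += gfa * res[b] / n

# Right-hand side of (2.2) with r = 2 for the pair (x_0, x_1).
rhs = -sum(K3(W, a, X[0], X[1], X[b]) * res[b] for b in range(n)) / n

# Central finite difference of K2 along the gradient-flow direction -grad L.
dt = 1e-5
lhs = (K2(W - dt * gW, a - dt * ga, X[0], X[1]) - K2(W + dt * gW, a + dt * ga, X[0], X[1])) / (2 * dt)
```

Here `lhs` and `rhs` agree up to finite-difference error, which is the statement $\partial_t K_t^{(2)} = -\frac{1}{n}\sum_\beta K_t^{(3)}(\cdot,\cdot,x_\beta)(f_\beta(t) - y_\beta)$ at the initial time.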

It was proven in [13, 27] that the change of the NTK for a wide deep neural network is upper bounded by $O(1/\sqrt{m})$. However, the numerical experiments in [27] indicate the change of the NTK is closer to $O(1/m)$. As a corollary of Theorem 2.3, we confirm the numerical observation that the NTK varies at a rate of order $O(1/m)$.

Corollary 2.4. Under Assumptions 2.1 and 2.2, the NTK $K_t^{(2)}(\cdot,\cdot)$ varies at a rate of order $O(1/m)$: with high probability with respect to the random initialization, there exist some constants $C, C' > 0$ such that for $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds

$$\|\partial_t K_t^{(2)}\|_\infty \lesssim \frac{(1+t)(\ln m)^C}{m}.$$
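Integrating the bound of Corollary 2.4 in time gives the total drift of the NTK away from its initial value. The short computation below is a direct consequence of the corollary on the same time window, not an additional result:

```latex
\|K_t^{(2)} - K_0^{(2)}\|_\infty
  \le \int_0^t \|\partial_s K_s^{(2)}\|_\infty \, ds
  \lesssim \frac{(\ln m)^C}{m} \int_0^t (1+s)\, ds
  = \frac{\big(t + t^2/2\big)(\ln m)^C}{m}.
```

In particular, for times of order one the kernel moves by at most $O((\ln m)^C/m)$, up to constants.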

As another corollary of Theorem 2.3, for a fully-connected wide neural network with $m \gtrsim n^3$, the gradient descent converges to zero training loss at a linear rate.

Corollary 2.5. Under Assumptions 2.1 and 2.2, we further assume that there exists $\lambda > 0$ (which might depend on $n$) such that

$$\lambda_{\min}\Big[K_0^{(2)}(x_\alpha,x_\beta)\Big]_{1\le\alpha,\beta\le n} \ge \lambda, \tag{2.5}$$

and the width $m$ of the neural network satisfies

$$m \ge C'\left(\frac{n}{\lambda}\right)^{3}(\ln m)^{C}\ln(n/\varepsilon)^{2}, \tag{2.6}$$

for some large constants $C, C' > 0$. Then with high probability with respect to the random initialization, the training error decays exponentially,

$$\sum_{\beta=1}^{n}\big(f_\beta(t) - y_\beta\big)^2 \lesssim n e^{-\frac{\lambda t}{2n}},$$

which reaches $\varepsilon$ at time $t \asymp (n/\lambda)\ln(n/\varepsilon)$.

It is proven in [13] that if there exists $\lambda^{(H)} > 0$ such that

$$\lambda_{\min}\Big[G_0^{(H)}(x_\alpha,x_\beta)\Big]_{1\le\alpha,\beta\le n} \ge \lambda^{(H)},$$

then for $m \ge C(n/\lambda^{(H)})^4$ the gradient descent finds a global minimum. Corollary 2.5 improves this result in two ways: (i) We improve the quartic dependence on $n$ to a cubic dependence. (ii) We recall that $K_0^{(2)} = \sum_{\ell=1}^{H+1} G_0^{(\ell)}$, and those kernels $G_0^{(\ell)}$ are all non-negative definite. The smallest eigenvalue of $K_0^{(2)}$ is typically much bigger than that of $G_0^{(H)}$, i.e., $\lambda \gg \lambda^{(H)}$. Moreover, since $K_t^{(2)}$ is a sum of $H+1$ non-negative definite operators, we expect that $\lambda$ gets larger if the depth $H$ is larger.

The NTH, i.e., (2.1) and (2.2), is just an infinite sequence of relationships. It cannot be used to determine the dynamic of the NTK. However, thanks to the a priori estimates of the higher order kernels (2.4), it holds that with high probability $\|K_t^{(p+1)}\|_\infty \lesssim (\ln m)^C/m^{p/2}$. The derivative $\partial_t K_t^{(p)}$ is an expression involving the higher order kernel $K_t^{(p+1)}$, which is small provided that $p$ is large enough. Therefore, we can approximate the original NTH (2.2) by simply setting $\partial_t K_t^{(p)} = 0$. In this way, we obtain the following truncated hierarchy of ordinary differential equations of $p$ levels, which we call the truncated NTH,

$$\begin{aligned}
\partial_t \tilde f_\alpha(t) &= -\frac{1}{n}\sum_{\beta=1}^{n}\tilde K_t^{(2)}(x_\alpha,x_\beta)\big(\tilde f_\beta(t) - y_\beta\big), \\
\partial_t \tilde K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) &= -\frac{1}{n}\sum_{\beta=1}^{n}\tilde K_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)\big(\tilde f_\beta(t) - y_\beta\big), \quad 2 \le r \le p-1, \\
\partial_t \tilde K_t^{(p)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_p}) &= 0,
\end{aligned} \tag{2.7}$$

where

$$\tilde f_\beta(0) = f_\beta(0), \quad \beta = 1, 2, \cdots, n, \qquad \tilde K_0^{(r)} = K_0^{(r)}, \quad r = 2, 3, \cdots, p.$$
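The truncated system (2.7) is a finite set of coupled ODEs on tensors indexed by the training points, so it can be integrated with any standard scheme. Below is a minimal forward-Euler sketch. The function name `truncated_nth`, the step count and the random initial kernels are illustrative assumptions: in practice the initial kernels would be computed from the network at initialization, and a real $K_0^{(3)}$ has symmetries that the random stand-in ignores.

```python
import numpy as np

def truncated_nth(f0, y, kernels0, t_end, steps):
    """Forward-Euler integration of the truncated hierarchy (2.7).

    kernels0 = [K2, K3, ..., Kp], where K_r is an array with r axes of length n
    holding K_0^{(r)} on the training set. The top kernel K_p stays frozen.
    """
    n = y.shape[0]
    f = f0.astype(float).copy()
    Ks = [K.astype(float).copy() for K in kernels0]
    dt = t_end / steps
    for _ in range(steps):
        res = f - y
        df = -(Ks[0] @ res) / n
        # d/dt K_r = -(1/n) sum_b K_{r+1}(..., x_b) (f_b - y_b) for 2 <= r <= p - 1;
        # `@` with a 1-D vector contracts the last axis of K_{r+1}.
        dKs = [-(Ks[i + 1] @ res) / n for i in range(len(Ks) - 1)]
        f += dt * df
        for i, dK in enumerate(dKs):
            Ks[i] += dt * dK
    return f, Ks

rng = np.random.default_rng(4)
n = 5
A = rng.standard_normal((n, n))
K2 = A @ A.T / n + np.eye(n)                  # hypothetical K_0^{(2)} (positive definite)
K3 = 0.05 * rng.standard_normal((n, n, n))    # hypothetical small K_0^{(3)}
y, f0 = rng.standard_normal(n), np.zeros(n)

f_p2, _ = truncated_nth(f0, y, [K2], t_end=2.0, steps=10000)           # p = 2: frozen NTK
f_p3, Ks_p3 = truncated_nth(f0, y, [K2, K3], t_end=2.0, steps=10000)   # p = 3: the NTK moves
```

With `kernels0 = [K2]` the scheme reduces to kernel regression with a frozen kernel; adding `K3` lets the kernel drift in a data dependent way, which is the finite width correction that the hierarchy is designed to capture.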

In the following theorem, we show this system of truncated equations (2.7) approximates the dynamic of the NTK up to arbitrary precision, provided that $p$ is large enough.

Theorem 2.6. Under Assumptions 2.1 and 2.2, we take an even $p$ and further assume that

$$\lambda_{\min}\Big[K_0^{(2)}(x_\alpha,x_\beta)\Big]_{1\le\alpha,\beta\le n} \ge \lambda.$$

Then there exist constants $c, C, C' > 0$ such that for $t$ satisfying

$$t \le \min\Big\{ c\sqrt{\lambda m/n}/(\ln m)^{C},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'} \Big\}, \tag{2.8}$$

the dynamic (2.1) can be approximated by the truncated dynamic (2.7),

$$\left(\sum_{\beta=1}^{n}\big(f_\beta(t) - \tilde f_\beta(t)\big)^2\right)^{1/2} \lesssim \frac{(1+t)t^{p-1}\sqrt{n}}{m^{p/2}}\min\Big\{t, \frac{n}{\lambda}\Big\}, \tag{2.9}$$

and

$$\big|K_t^{(2)}(x_\alpha,x_\beta) - \tilde K_t^{(2)}(x_\alpha,x_\beta)\big| \lesssim \frac{(1+t)t^{p-1}}{m^{p/2}}\left(1 + \frac{(1+t)t(\ln m)^C}{m}\min\Big\{t, \frac{n}{\lambda}\Big\}\right). \tag{2.10}$$

We remark that the error terms, i.e., the right-hand sides of (2.9) and (2.10), can be arbitrarily small, provided that $p$ is large enough. In other words, if we take $p$ large enough, the truncated NTH (2.7) can approximate the original dynamic (2.1), (2.2) up to any precision, provided that the time constraint (2.8) is satisfied. Now take $t \asymp (n/\lambda)\ln(n/\varepsilon)$, so that Corollary 2.5 guarantees the convergence of the dynamics, and consider two special cases: (i) If we take $p = 2$, then the error in (2.9) is $O(n^{7/2}\ln(n/\varepsilon)^3/\lambda^3 m)$ when $t \asymp (n/\lambda)\ln(n/\varepsilon)$, which is negligible provided that the width $m$ is much bigger than $n^{7/2}$. We conclude that if $m$ is much bigger than $n^{7/2}$, the truncated NTH gives a complete description of the original dynamic of the NTK up to the equilibrium. The condition that $m$ is much bigger than $n^{7/2}$ is better than the previous best available one, which requires $m \gtrsim n^4$. (ii) If we take $p = 3$, then the error in (2.9) is $O(n^{9/2}\ln(n/\varepsilon)^4/\lambda^4 m^{3/2})$ when $t \asymp (n/\lambda)\ln(n/\varepsilon)$, which is negligible provided that the width $m$ is much bigger than $n^3$. We conclude that if $m$ is much bigger than $n^3$, the truncated NTH gives a complete description of the original dynamic of the NTK up to the equilibrium. Finally, we note that the estimates in Theorem 2.6 clearly improve for smaller $t$.

The previous convergence theory of overparametrized neural networks works only for very wide neural networks, i.e., $m \gtrsim n^3$. For any width (not necessarily $m \gtrsim n^3$), Theorem 2.6 guarantees that the truncated NTH approximates the training dynamics of deep neural networks. The effect of the width appears in the approximation time and the error terms (2.9) and (2.10): the wider the neural network, the longer the truncated dynamic (2.7) approximates the training dynamic and the smaller the approximation error. We recall from (1.7) that the NTK is the sum of $H+1$ non-negative definite operators, $K_t^{(2)} = \sum_{\ell=1}^{H+1} G_t^{(\ell)}$. We expect that $\lambda$ gets bigger if the depth $H$ is larger. Therefore, large width and depth make the truncated dynamic (2.7) a better approximation.

Thanks to Theorem 2.6, the truncated NTH (2.7) provides a good approximation for the evolution of the NTK. The truncated dynamic can be used to predict the output of new data points. Recall that the training data are $\{(x_\beta, y_\beta)\}_{1\le\beta\le n} \subset \mathbb{R}^d \times \mathbb{R}$. The goal is to predict the output of a new data point $x$. To do this, we can first use the truncated dynamic to solve for the approximated outputs $\{\tilde f_\beta(t)\}_{1\le\beta\le n}$. Then the prediction on the new test point $x \in \mathbb{R}^d$ can be estimated by sequentially solving the higher order kernels $\tilde K_t^{(p)}(x, \mathcal{X}^{p-1}), \tilde K_t^{(p-1)}(x, \mathcal{X}^{p-2}), \cdots, \tilde K_t^{(2)}(x, \mathcal{X})$ and $\tilde f_x(t)$,

$$\begin{aligned}
\partial_t \tilde f_x(t) &= -\frac{1}{n}\sum_{\beta=1}^{n}\tilde K_t^{(2)}(x,x_\beta)\big(\tilde f_\beta(t) - y_\beta\big), \\
\partial_t \tilde K_t^{(r)}(x, x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_{r-1}}) &= -\frac{1}{n}\sum_{\beta=1}^{n}\tilde K_t^{(r+1)}(x, x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_{r-1}}, x_\beta)\big(\tilde f_\beta(t) - y_\beta\big), \quad 2 \le r \le p-1, \\
\partial_t \tilde K_t^{(p)}(x, x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_{p-1}}) &= 0.
\end{aligned} \tag{2.11}$$
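The prediction scheme (2.11) couples the test point to the training residuals only, so it can be integrated alongside (2.7). The following sketch does this with forward Euler. The function name, the tensor layout (first slot fixed to the test point) and the random initial kernels are assumptions made for illustration. As a sanity check, the "new" point is taken to be a training point, in which case its prediction must coincide with the corresponding training output.

```python
import numpy as np

def predict_truncated(f0, fx0, y, kernels0, kernels_x0, t_end, steps):
    """Euler integration of (2.7) together with (2.11) for one test point x.

    kernels0[i]   holds K_0^{(i+2)} on the training set (i + 2 axes of length n).
    kernels_x0[i] holds K_0^{(i+2)}(x, ...) with the first slot fixed to x (i + 1 axes).
    """
    n = y.shape[0]
    f, fx = f0.astype(float).copy(), float(fx0)
    Ks = [K.astype(float).copy() for K in kernels0]
    Kx = [K.astype(float).copy() for K in kernels_x0]
    dt = t_end / steps
    for _ in range(steps):
        res = f - y                                   # training residuals drive every equation
        df, dfx = -(Ks[0] @ res) / n, -(Kx[0] @ res) / n
        dKs = [-(Ks[i + 1] @ res) / n for i in range(len(Ks) - 1)]
        dKx = [-(Kx[i + 1] @ res) / n for i in range(len(Kx) - 1)]
        f += dt * df
        fx += dt * dfx
        for i in range(len(dKs)):
            Ks[i] += dt * dKs[i]
            Kx[i] += dt * dKx[i]
    return fx, f

rng = np.random.default_rng(5)
n = 4
A = rng.standard_normal((n, n))
K2 = A @ A.T / n + np.eye(n)                  # hypothetical K_0^{(2)}
K3 = 0.05 * rng.standard_normal((n, n, n))    # hypothetical K_0^{(3)}
y, f0 = rng.standard_normal(n), np.zeros(n)

# Sanity check: use the first training point as the "new" point x.
fx, f = predict_truncated(f0, f0[0], y, [K2, K3], [K2[0], K3[0]], t_end=1.0, steps=5000)
```

For a genuinely new point, `kernels_x0` would hold the initial kernels evaluated with $x$ in the first argument, and `fx0` the initial network output $f(x,\theta_0)$.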

3 Technique overview

We recall the NTK from (1.7),

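$$K_t^{(2)}(x_\alpha,x_\beta) = \langle x_\alpha^{(H)}, x_\beta^{(H)}\rangle + \sum_{\ell=1}^{H}\left\langle \sigma'_\ell(x_\alpha)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\alpha)\frac{a_t}{\sqrt{m}},\ \sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right\rangle\langle x_\alpha^{(\ell-1)}, x_\beta^{(\ell-1)}\rangle.$$

The kernel $K_t^{(2)}(\cdot,\cdot)$ is a sum of terms, which are products of inner products of vectors involving the quantities $a_t$, $W_t^{(\ell)}$, $x_\alpha^{(\ell)}$ and $\sigma'_\ell(x_\alpha)$. To compute the derivatives of $K_t^{(2)}(\cdot,\cdot)$, we need the following ordinary differential equations, derived by using (1.4), (1.5) and the chain rule, which characterize the dynamics of $a_t$, $W_t^{(\ell)}$, $x_\alpha^{(\ell)}$ and $\sigma_\ell^{(r)}(x_\alpha)$ along the gradient flow.

$$\begin{aligned}
\partial_t a_t &= -\frac{1}{n}\sum_{\beta=1}^{n} x_\beta^{(H)}\big(f_\beta(t) - y_\beta\big), \\
\partial_t \frac{W_t^{(\ell)}}{\sqrt{m}} &= -\frac{1}{n}\sum_{\beta=1}^{n}\mathrm{diag}\left(\sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\frac{\mathbf{1}}{\sqrt{m}}\otimes(x_\beta^{(\ell-1)})^\top\big(f_\beta(t) - y_\beta\big), \\
\partial_t \frac{(W_t^{(\ell)})^\top}{\sqrt{m}} &= -\frac{1}{n}\sum_{\beta=1}^{n}\frac{1}{\sqrt{m}}\,x_\beta^{(\ell-1)}\otimes\left(\frac{a_t^\top}{\sqrt{m}}\sigma'_H(x_\beta)\cdots\frac{W_t^{(\ell+1)}}{\sqrt{m}}\sigma'_\ell(x_\beta)\right)\big(f_\beta(t) - y_\beta\big), \\
\partial_t x_\alpha^{(\ell)} &= -\frac{1}{n}\sum_{k=1}^{\ell}\sum_{\beta=1}^{n}\mathrm{diag}\left(\sigma'_\ell(x_\alpha)\frac{W_t^{(\ell)}}{\sqrt{m}}\cdots\frac{W_t^{(k+1)}}{\sqrt{m}}\sigma'_k(x_\alpha)\sigma'_k(x_\beta)\frac{(W_t^{(k+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right) \\
&\qquad\times\frac{\mathbf{1}}{\sqrt{m}}\langle x_\alpha^{(k-1)}, x_\beta^{(k-1)}\rangle\big(f_\beta(t) - y_\beta\big), \\
\partial_t \sigma_\ell^{(r)}(x_\alpha) &= \sigma_\ell^{(r+1)}(x_\alpha)\,\mathrm{diag}\big(\partial_t(W_t^{(\ell)}x_\alpha^{(\ell-1)})\big) \\
&= -\frac{1}{n}\sum_{\beta=1}^{n}\sigma_\ell^{(r+1)}(x_\alpha)\,\mathrm{diag}\left(\sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\langle x_\alpha^{(\ell-1)}, x_\beta^{(\ell-1)}\rangle\big(f_\beta(t) - y_\beta\big) \\
&\quad -\frac{1}{n}\sum_{k=1}^{\ell-1}\sum_{\beta=1}^{n}\sigma_\ell^{(r+1)}(x_\alpha)\,\mathrm{diag}\left(\frac{W_t^{(\ell)}}{\sqrt{m}}\sigma'_{\ell-1}(x_\alpha)\cdots\frac{W_t^{(k+1)}}{\sqrt{m}}\sigma'_k(x_\alpha)\sigma'_k(x_\beta)\frac{(W_t^{(k+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right) \\
&\qquad\times\langle x_\alpha^{(k-1)}, x_\beta^{(k-1)}\rangle\big(f_\beta(t) - y_\beta\big).
\end{aligned}$$

We remark that the $k = \ell$ term on the right-hand side of the expression for $\partial_t x_\alpha^{(\ell)}$ is

$$-\frac{1}{n}\sum_{\beta=1}^{n}\mathrm{diag}\left(\sigma'_\ell(x_\alpha)\sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\frac{\mathbf{1}}{\sqrt{m}}\langle x_\alpha^{(\ell-1)}, x_\beta^{(\ell-1)}\rangle\big(f_\beta(t) - y_\beta\big).$$

All other cases with $k < \ell$ can be read clearly from the expression for $\partial_t x_\alpha^{(\ell)}$ given above.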

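Using the chain rule and the above expressions, the derivative of $K_t^{(2)}(\cdot,\cdot)$ is given by

$$\partial_t K_t^{(2)}(x_{\alpha_1}, x_{\alpha_2}) = -\frac{1}{n}\sum_{\beta=1}^{n}K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_\beta)\big(f(x_\beta,\theta_t) - y_\beta\big),$$

where $K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_\beta)$ is the sum of all the possible terms obtained from $K_t^{(2)}(x_{\alpha_1}, x_{\alpha_2})$ by performing one of the following replacements:

$$\begin{aligned}
a_t &\to x_\beta^{(H)}, \\
\frac{W_t^{(\ell)}}{\sqrt{m}} &\to \mathrm{diag}\left(\sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\frac{\mathbf{1}}{\sqrt{m}}\otimes(x_\beta^{(\ell-1)})^\top, \\
\frac{(W_t^{(\ell)})^\top}{\sqrt{m}} &\to \frac{1}{\sqrt{m}}\,x_\beta^{(\ell-1)}\otimes\left(\frac{a_t^\top}{\sqrt{m}}\sigma'_H(x_\beta)\cdots\frac{W_t^{(\ell+1)}}{\sqrt{m}}\sigma'_\ell(x_\beta)\right), \\
x_\alpha^{(\ell)} &\to \sum_{k=1}^{\ell}\mathrm{diag}\left(\sigma'_\ell(x_\alpha)\frac{W_t^{(\ell)}}{\sqrt{m}}\cdots\frac{W_t^{(k+1)}}{\sqrt{m}}\sigma'_k(x_\alpha)\sigma'_k(x_\beta)\frac{(W_t^{(k+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\frac{\mathbf{1}}{\sqrt{m}}\langle x_\alpha^{(k-1)}, x_\beta^{(k-1)}\rangle, \\
\sigma_\ell^{(r)}(x_\alpha) &\to \sigma_\ell^{(r+1)}(x_\alpha)\,\mathrm{diag}\left(\sigma'_\ell(x_\beta)\frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\langle x_\alpha^{(\ell-1)}, x_\beta^{(\ell-1)}\rangle \\
&\quad + \sigma_\ell^{(r+1)}(x_\alpha)\sum_{k=1}^{\ell-1}\mathrm{diag}\left(\frac{W_t^{(\ell)}}{\sqrt{m}}\sigma'_{\ell-1}(x_\alpha)\cdots\frac{W_t^{(k+1)}}{\sqrt{m}}\sigma'_k(x_\alpha)\sigma'_k(x_\beta)\frac{(W_t^{(k+1)})^\top}{\sqrt{m}}\cdots\sigma'_H(x_\beta)\frac{a_t}{\sqrt{m}}\right)\langle x_\alpha^{(k-1)}, x_\beta^{(k-1)}\rangle,
\end{aligned} \tag{3.1}$$

with $\alpha = \alpha_1, \alpha_2$, where $\mathbf{1} = (1, 1, \cdots, 1)^\top \in \mathbb{R}^m$. By the same reasoning, the derivative of $K_t^{(r)}$ is given by

$$\partial_t K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}) = -\frac{1}{n}\sum_{\beta=1}^{n}K_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)\big(f(x_\beta,\theta_t) - y_\beta\big),$$

where $K_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)$ is the sum of all the possible terms obtained from $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ by performing any of the replacements in (3.1) with $\alpha = \alpha_1, \alpha_2, \cdots, \alpha_r$.

The following are some examples of terms in $K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})$:

$$\left\langle \sigma'_{H-1}(x_{\alpha_1})\frac{(W_t^{(H)})^\top}{\sqrt{m}}\sigma_H^{(2)}(x_{\alpha_1})\,\mathrm{diag}\left(\sigma'_H(x_{\alpha_3})\frac{a_t}{\sqrt{m}}\right)\frac{a_t}{\sqrt{m}},\ \sigma'_{H-1}(x_{\alpha_2})\frac{(W_t^{(H)})^\top}{\sqrt{m}}\sigma'_H(x_{\alpha_2})\frac{a_t}{\sqrt{m}}\right\rangle \times \langle x_{\alpha_1}^{(H-2)}, x_{\alpha_2}^{(H-2)}\rangle\langle x_{\alpha_1}^{(H-1)}, x_{\alpha_3}^{(H-1)}\rangle;$$

$$\frac{1}{\sqrt{m}}\left\langle \sigma'_H(x_{\alpha_1})\frac{a_t}{\sqrt{m}},\ \sigma'_H(x_{\alpha_3})\frac{a_t}{\sqrt{m}}\right\rangle \times \left\langle \sigma'_{H-1}(x_{\alpha_1})x_{\alpha_3}^{(H-1)},\ \sigma'_{H-1}(x_{\alpha_2})\frac{(W_t^{(H)})^\top}{\sqrt{m}}\sigma'_H(x_{\alpha_2})\frac{a_t}{\sqrt{m}}\right\rangle\langle x_{\alpha_1}^{(H-2)}, x_{\alpha_2}^{(H-2)}\rangle.$$

In general, from the construction, the summands appearing in $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ are products of inner products of vectors obtained in the following way: starting from one of the vectors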

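$$\frac{a_t}{\sqrt{m}},\quad \frac{\mathbf{1}}{\sqrt{m}},\quad \{x_\beta^{(1)}, x_\beta^{(2)}, \cdots, x_\beta^{(H)}\}_{\beta\in\{\alpha_1,\alpha_2,\cdots,\alpha_r\}}, \tag{3.2}$$

(i) multiply by one of the matrices

$$\left\{\frac{W_t^{(2)}}{\sqrt{m}}, \frac{(W_t^{(2)})^\top}{\sqrt{m}}, \cdots, \frac{W_t^{(H)}}{\sqrt{m}}, \frac{(W_t^{(H)})^\top}{\sqrt{m}}\right\},\quad \{\sigma'_1(x_\beta), \sigma'_2(x_\beta), \cdots, \sigma'_H(x_\beta)\}_{\beta\in\{\alpha_1,\alpha_2,\cdots,\alpha_r\}}; \tag{3.3}$$

(ii) multiply by one of the matrices

$$\mathrm{diag}(\cdots),\quad \sigma_\ell^{(s)}(x_\beta)\underbrace{\mathrm{diag}(\cdots)\cdots\mathrm{diag}(\cdots)}_{s-1\text{ terms}},\quad s \ge 2, \tag{3.4}$$

where $\mathrm{diag}(\cdots)$ is the diagonalization of a vector obtained by recursively using (i) and (ii).

To describe the vectors appearing in $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ in a formal way, we need to introduce some more notation. We denote by $D_0$ the set of expressions of the following form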

$$D_0 := \{e_s e_{s-1}\cdots e_1 e_0 : 0 \le s \le 4H-3\}, \tag{3.5}$$

where the $e_j$ are chosen from the following sets:

$$e_0 \in \Big\{a_t,\ \{\sqrt{m}\,x_\beta^{(1)}, \sqrt{m}\,x_\beta^{(2)}, \cdots, \sqrt{m}\,x_\beta^{(H)}\}_{1\le\beta\le n}\Big\},$$

and for $1 \le j \le s$,

$$e_j \in \left\{\left\{\frac{W_t^{(2)}}{\sqrt{m}}, \frac{(W_t^{(2)})^\top}{\sqrt{m}}, \cdots, \frac{W_t^{(H)}}{\sqrt{m}}, \frac{(W_t^{(H)})^\top}{\sqrt{m}}\right\},\ \{\sigma'_1(x_\beta), \sigma'_2(x_\beta), \cdots, \sigma'_H(x_\beta)\}_{1\le\beta\le n}\right\}.$$

We remark that from expression (1.7), each summand in $K_t^{(2)}(x_{\alpha_1}, x_{\alpha_2})$ is of the form

$$\frac{\langle v_1(t), v_2(t)\rangle}{m},\qquad \frac{\langle v_1(t), v_2(t)\rangle}{m}\frac{\langle v_3(t), v_4(t)\rangle}{m},$$

where $v_1(t), v_2(t), v_3(t), v_4(t) \in D_0$. But the set $D_0$ contains more terms than those appearing in $K_t^{(2)}(\cdot,\cdot)$. Given that we have constructed $D_0, D_1, \cdots, D_r$, we denote by $D_{r+1}$ the set of expressions of the following form

$$D_{r+1} := \{e_s e_{s-1}\cdots e_1 e_0 : 0 \le s \le 4H-3\}, \tag{3.6}$$

where the $e_j$ are chosen from the following sets (notice that we have included $\mathbf{1}$ in the following set, which does not appear in the definition of $D_0$):

$$e_0 \in \Big\{a_t,\ \mathbf{1},\ \{\sqrt{m}\,x_\beta^{(1)}, \sqrt{m}\,x_\beta^{(2)}, \cdots, \sqrt{m}\,x_\beta^{(H)}\}_{1\le\beta\le n}\Big\},$$

and for $1 \le j \le s$, $e_j$ belongs to one of the sets

$$\begin{aligned}
&\left\{\left\{\frac{W_t^{(2)}}{\sqrt{m}}, \frac{(W_t^{(2)})^\top}{\sqrt{m}}, \cdots, \frac{W_t^{(H)}}{\sqrt{m}}, \frac{(W_t^{(H)})^\top}{\sqrt{m}}\right\},\ \{\sigma'_1(x_\beta), \sigma'_2(x_\beta), \cdots, \sigma'_H(x_\beta)\}_{1\le\beta\le n}\right\}, \\
&\{\mathrm{diag}(d) : d \in D_0\cup D_1\cup\cdots\cup D_r\}, \\
&\Big\{\sigma_\ell^{(u+1)}(x_\beta)\,\mathrm{diag}(d_1)\,\mathrm{diag}(d_2)\cdots\mathrm{diag}(d_u) : 1 \le \ell \le H,\ 1 \le \beta \le n,\ 1 \le u \le r,\ d_1, d_2, \cdots, d_u \in D_0\cup D_1\cup\cdots\cup D_r\Big\}.
\end{aligned}$$

Moreover, the total number of diag operations in the expression $e_s e_{s-1}\cdots e_1 e_0 \in D_{r+1}$ is exactly $r+1$. We remark that if $d \in D_s$, then it contains $s$ diag operations. On the other hand, by definition, we view $\mathrm{diag}(d)$ as an element with $s+1$ diag operations, because the diag in $\mathrm{diag}(d)$ is counted as one diag operation.

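The kernel $K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})$ is obtained from $K_t^{(2)}(x_{\alpha_1}, x_{\alpha_2})$ by the replacements (3.1), taking $\alpha = \alpha_1, \alpha_2$ and $\beta = \alpha_3$. The summands in $K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})$ are of the forms

$$\frac{1}{\sqrt{m}}\frac{\langle v_1(t), v_2(t)\rangle}{m},\qquad \frac{1}{\sqrt{m}}\frac{\langle v_1(t), v_2(t)\rangle}{m}\frac{\langle v_3(t), v_4(t)\rangle}{m},\qquad \frac{1}{\sqrt{m}}\frac{\langle v_1(t), v_2(t)\rangle}{m}\frac{\langle v_3(t), v_4(t)\rangle}{m}\frac{\langle v_5(t), v_6(t)\rangle}{m}, \tag{3.7}$$

where $v_1(t), v_2(t), \cdots, v_6(t) \in D_0\cup D_1$. The first two terms in (3.7) are obtained from using the replacements for $a_t$, and the last two terms in (3.7) are obtained from using the replacements for $W_t^{(\ell)}/\sqrt{m}$, $(W_t^{(\ell)})^\top/\sqrt{m}$, $x_\alpha^{(\ell)}$ and $\sigma_\ell^{(r)}(x_\alpha)$. More generally, we will show that each summand in $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ is of the form

$$\frac{1}{m^{r/2-1}}\prod_{j=1}^{s}\frac{\langle v_{2j-1}(t), v_{2j}(t)\rangle}{m},\quad 1 \le s \le r,\quad v_1(t), v_2(t), \cdots, v_{2s}(t) \in D_0\cup D_1\cup\cdots\cup D_{r-2}. \tag{3.8}$$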

The initial value $K_0^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ can be estimated by successively conditioning based on the depth of the neural network. A convenient scheme is given by the tensor program [45], which was developed to characterize the scaling limit of neural network computations. In Appendix A, we show that at time $t = 0$, the vectors $v_j(0)$ in (3.8) are combinations of projections of independent Gaussian vectors. As a consequence, $\langle v_{2j-1}(0), v_{2j}(0)\rangle/m$ concentrates around a certain constant with high probability, and so does the product $\prod_{j=1}^{s}\langle v_{2j-1}(0), v_{2j}(0)\rangle/m$. This gives the claim (2.3).

In Appendix B, we consider the quantity

$$\xi(t) = \max\{\|v_j(t)\|_\infty : v_j(t) \in D_0\cup D_1\cup\cdots\cup D_{p-1}\}.$$

Again using the tensor program, we show that with high probability $\|v_j(0)\|_\infty \lesssim (\ln m)^C$. This gives the estimate of $\xi(t)$ at $t = 0$. Next we show that the $(p+1)$-th derivative of $\xi(t)$ can be controlled by $\xi(t)$ itself. This gives a self-consistent differential equation for $\xi(t)$:

$$\partial_t^{(p+1)}\xi(t) \lesssim \frac{\xi(t)^{2p}}{m^{p/2}}. \tag{3.9}$$

Combining this with the initial estimate of $\xi(t)$, it follows that for time $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds that $\xi(t) \lesssim (\ln m)^C$. In particular, $\|v_j(t)\|_\infty \lesssim (\ln m)^C$. Then the claim (2.4) in Theorem 2.3 follows.

Thanks to the a priori estimate (2.4), we show that along the continuous time gradient descent, the higher order kernels $K_t^{(r)}$ vary slowly. We prove Corollaries 2.4 and 2.5, and Theorem 2.6, in Appendix C by a Grönwall type argument.

4 Discussion and future directions

In this paper, we study the continuous time gradient descent (gradient flow) of deep fully-connected neural networks. We show that the training dynamic is given by a data dependent infinite hierarchy of ordinary differential equations, i.e., the NTH. We also show that this dynamic of the NTK can be approximated by a finite truncated dynamic up to any precision. This description makes it possible to directly study the change of the NTK for deep neural networks. Here we list some future directions.

Firstly, we mainly study deep fully-connected neural networks here, but we believe the same statements can be proven for convolutional and residual neural networks.

Secondly, in this paper, for simplicity, we focus on the continuous time gradient descent. The approach developed here can be generalized to analyze discrete time gradient descent. We elaborate the main idea here. The discrete time gradient descent is given by

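$$\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t) = \theta_t - \frac{\eta}{n} \sum_{\beta=1}^{n} \nabla_\theta f_\beta(t)\,(f_\beta(t) - y_\beta),$$

where $\eta$ is the learning rate. We write the NTK as $\mathcal{K}^{(2)}(x_\alpha, x_\beta; \theta_t)$ to make the dependence on $\theta_t$ explicit. To estimate the NTK $\mathcal{K}^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_{t+1})$ at time $t+1$, we use the Taylor expansion in $\theta$ around $\theta_t$,

$$\mathcal{K}^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_{t+1}) \approx \mathcal{K}^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_t) + \sum_{r=3}^{p-1} \frac{(-\eta)^{r-2}}{(r-2)!\, n^{r-2}} \sum_{1 \le \beta_1, \beta_2, \cdots, \beta_{r-2} \le n} \mathcal{K}^{(r)}(x_{\alpha_1}, x_{\alpha_2}, x_{\beta_1}, \cdots, x_{\beta_{r-2}}; \theta_t)\,(f_{\beta_1}(t) - y_{\beta_1}) \cdots (f_{\beta_{r-2}}(t) - y_{\beta_{r-2}}), \tag{4.1}$$

where the higher order kernels $\mathcal{K}^{(r)}$ are given by

$$\mathcal{K}^{(r)}(x_{\alpha_1}, x_{\alpha_2}, x_{\beta_1}, \cdots, x_{\beta_{r-2}}; \theta_t) = \nabla_\theta^{(r-2)} \mathcal{K}^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_t)\big(\nabla_\theta f_{\beta_1}(t), \nabla_\theta f_{\beta_2}(t), \cdots, \nabla_\theta f_{\beta_{r-2}}(t)\big).$$

The term with index $r$ in (4.1) is the $(r-2)$-th order term of the expansion, since each of the $r-2$ directions $\theta_{t+1} - \theta_t$ contributes one factor $-\eta/n$. A similar argument as for (2.4) can be used to derive the a priori estimates of these kernels $\mathcal{K}^{(r)}$. We expect to have $\|\mathcal{K}^{(r)}\|_\infty \lesssim (\ln m)^C / m^{r/2-1}$ with high probability with respect to the random initialization. Therefore the right-hand side of (4.1) gives an approximation of the NTK $\mathcal{K}^{(2)}(x_{\alpha_1}, x_{\alpha_2}; \theta_{t+1})$ at time $t+1$ up to arbitrary precision, provided that $p$ is large enough. This gives a description of the NTK dynamics under discrete time gradient descent.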
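The discrete-time update above can be probed numerically. The following is only a minimal sketch under assumptions that are not taken from the paper: a two-layer network $f(x) = a^\top \tanh(Wx)/\sqrt{m}$ (depth $H = 1$), unit-variance Gaussian initialization, random unit-norm inputs, and the hypothetical helper names `ntk_two_layer`, `gd_step` and `kernel_change`. It takes one gradient descent step and measures how much the empirical NTK moves, which one would expect to shrink as the width $m$ grows.

```python
import numpy as np

def ntk_two_layer(W, a, X):
    """Empirical NTK of f(x) = a^T tanh(W x) / sqrt(m) on the rows of X (assumed model)."""
    m = W.shape[0]
    act = np.tanh(X @ W.T)              # (n, m), sigma(W x)
    dact = 1.0 - act ** 2               # (n, m), sigma'(W x)
    k_a = act @ act.T / m               # gradients with respect to a
    k_w = ((dact * a) @ (dact * a).T) * (X @ X.T) / m  # gradients with respect to W
    return k_a + k_w

def gd_step(W, a, X, y, lr):
    """One step theta <- theta - (lr/n) * sum_beta grad f_beta * (f_beta - y_beta)."""
    m, n = W.shape[0], X.shape[0]
    act = np.tanh(X @ W.T)
    dact = 1.0 - act ** 2
    resid = act @ a / np.sqrt(m) - y                     # f_beta - y_beta
    grad_a = act.T @ resid / (n * np.sqrt(m))            # shape (m,)
    grad_W = ((dact * a) * resid[:, None]).T @ X / (n * np.sqrt(m))  # shape (m, d)
    return W - lr * grad_W, a - lr * grad_a

def kernel_change(m, n=5, d=3, lr=1.0, seed=0):
    """Max entrywise change of the empirical NTK after one gradient descent step."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)        # unit-norm inputs (assumption)
    y = rng.normal(size=n)
    W, a = rng.normal(size=(m, d)), rng.normal(size=m)
    K0 = ntk_two_layer(W, a, X)
    W1, a1 = gd_step(W, a, X, y, lr)
    return float(np.abs(ntk_two_layer(W1, a1, X) - K0).max())

if __name__ == "__main__":
    for width in (64, 512, 4096):
        print(width, kernel_change(width))
```

Comparing `kernel_change(64)` with `kernel_change(4096)` gives a rough picture of the finite-width effect; the sketch does not evaluate the higher order kernels in (4.1).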

References

[1] Z. Allen-Zhu and Y. Li. What can resnet learn efficiently, going beyond kernels? arXiv preprint arXiv:1905.10337, 2019.
[2] Z. Allen-Zhu, Y. Li, and Y. Liang. Learning and generalization in overparameterized neural networks, going beyond two layers. arXiv preprint arXiv:1811.04918, 2018.
[3] Z. Allen-Zhu, Y. Li, and Z. Song. A convergence theory for deep learning via over-parameterization. In ICML, arXiv:1811.03962, 2018.
[4] D. Araújo, R. I. Oliveira, and D. Yukimura. A mean-field limit for certain deep neural networks. arXiv preprint arXiv:1906.00193, 2019.
[5] S. Arora, S. S. Du, W. Hu, Z. Li, R. Salakhutdinov, and R. Wang. On exact computation with an infinitely wide neural net. arXiv preprint arXiv:1904.11955, 2019.
[6] S. Arora, S. S. Du, W. Hu, Z. Li, and R. Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. arXiv preprint arXiv:1901.08584, 2019.
[7] S. Bhojanapalli, B. Neyshabur, and N. Srebro. Global optimality of local search for low rank matrix recovery. In Advances in Neural Information Processing Systems, pages 3873–3881, 2016.
[8] L. Chizat and F. Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems, pages 3036–3046, 2018.
[9] A. Choromanska, M. Henaff, M. Mathieu, G. Ben Arous, and Y. LeCun. The loss surfaces of multilayer networks. In Proceedings of the Eighteenth International Conference on Artificial Intelligence and Statistics, pages 192–204, 2015.
[10] R. Collobert, J. Weston, L. Bottou, M. Karlen, K. Kavukcuoglu, and P. Kuksa. Natural language processing (almost) from scratch. Journal of Machine Learning Research, 12(Aug):2493–2537, 2011.
[11] Y. N. Dauphin, R. Pascanu, C. Gulcehre, K. Cho, S. Ganguli, and Y. Bengio. Identifying and attacking the saddle point problem in high-dimensional non-convex optimization. In Advances in Neural Information Processing Systems, pages 2933–2941, 2014.
[12] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
[13] S. S. Du, J. D. Lee, H. Li, L. Wang, and X. Zhai. Gradient descent finds global minima of deep neural networks. ICML, arXiv:1811.03804, 2018.
[14] S. S. Du, X. Zhai, B. Poczos, and A. Singh. Gradient descent provably optimizes over-parameterized neural networks. In ICLR, arXiv:1810.02054, 2018.
[15] R. Ge, F. Huang, C. Jin, and Y. Yuan. Escaping from saddle points—online stochastic gradient for tensor decomposition. In Proceedings of The 28th Conference on Learning Theory, pages 797–842, 2015.
[16] R. Ge, J. D. Lee, and T. Ma. Matrix completion has no spurious local minimum. In Advances in Neural Information Processing Systems, pages 2973–2981, 2016.
[17] B. Ghorbani, S. Mei, T. Misiakiewicz, and A. Montanari. Linearized two-layers neural networks in high dimension. arXiv preprint arXiv:1904.12191, 2019.
[18] X. Glorot and Y. Bengio. Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pages 249–256, 2010.
[19] G. Hinton, L. Deng, D. Yu, G. Dahl, A.-r. Mohamed, N. Jaitly, A. Senior, V. Vanhoucke, P. Nguyen, B. Kingsbury, et al. Deep neural networks for acoustic modeling in speech recognition. IEEE Signal Processing Magazine, 29, 2012.
[20] A. Jacot, F. Gabriel, and C. Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pages 8571–8580, 2018.
[21] C. Jin, R. Ge, P. Netrapalli, S. M. Kakade, and M. I. Jordan. How to escape saddle points efficiently. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1724–1732. JMLR.org, 2017.
[22] K. Kawaguchi. Deep learning without poor local minima. In Advances in Neural Information Processing Systems, pages 586–594, 2016.
[23] K. Kawaguchi and J. Huang. Gradient descent finds global minima for generalizable deep neural networks of practical sizes. arXiv preprint arXiv:1908.02419, 2019.
[24] K. Kawaguchi and L. P. Kaelbling. Elimination of all bad local minima in deep learning. arXiv preprint arXiv:1901.00279, 2019.
[25] A. Krizhevsky, I. Sutskever, and G. E. Hinton. Imagenet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1097–1105, 2012.
[26] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
[27] J. Lee, L. Xiao, S. S. Schoenholz, Y. Bahri, J. Sohl-Dickstein, and J. Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. arXiv preprint arXiv:1902.06720, 2019.
[28] J. D. Lee, M. Simchowitz, M. I. Jordan, and B. Recht. Gradient descent only converges to minimizers. In Conference on Learning Theory, pages 1246–1257, 2016.
[29] Y. Li and Y. Liang. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems, pages 8157–8166, 2018.
[30] S. Liang, R. Sun, J. D. Lee, and R. Srikant. Adding one neuron can eliminate all bad local minima. In Advances in Neural Information Processing Systems, 2018.
[31] S. Mei, T. Misiakiewicz, and A. Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. arXiv preprint arXiv:1902.06015, 2019.
[32] P.-M. Nguyen. Mean field limit of the learning dynamics of multilayer neural networks. arXiv preprint arXiv:1902.02880, 2019.
[33] D. Park, A. Kyrillidis, C. Caramanis, and S. Sanghavi. Non-square matrix sensing without spurious local minima via the burer-monteiro approach. arXiv preprint arXiv:1609.03240, 2016.
[34] T. N. Sainath, A.-r. Mohamed, B. Kingsbury, and B. Ramabhadran. Deep convolutional neural networks for lvcsr. In 2013 IEEE International Conference on Acoustics, Speech and Signal Processing, pages 8614–8618. IEEE, 2013.
[35] D. Silver, A. Huang, C. J. Maddison, A. Guez, L. Sifre, G. Van Den Driessche, J. Schrittwieser, I. Antonoglou, V. Panneershelvam, M. Lanctot, et al. Mastering the game of go with deep neural networks and tree search. Nature, 529(7587):484–489, 2016.
[36] D. Silver, J. Schrittwieser, K. Simonyan, I. Antonoglou, A. Huang, A. Guez, T. Hubert, L. Baker, M. Lai, A. Bolton, et al. Mastering the game of go without human knowledge. Nature, 550(7676):354, 2017.
[37] J. Sirignano and K. Spiliopoulos. Mean field analysis of deep neural networks. arXiv preprint arXiv:1903.04440, 2019.
[38] S. Mei, A. Montanari, and P.-M. Nguyen. A mean field view of the landscape of two-layers neural networks. Proceedings of the National Academy of Sciences, 115:E7665–E7671, 2018.
[39] Z. Song and X. Yang. Quadratic suffices for over-parametrization via matrix chernoff bound. arXiv preprint arXiv:1906.03593, 2019.
[40] J. Sun, Q. Qu, and J. Wright. Complete dictionary recovery over the sphere i: Overview and the geometric picture. IEEE Transactions on Information Theory, 63(2):853–884, 2016.
[41] J. Sun, Q. Qu, and J. Wright. A geometric analysis of phase retrieval. Foundations of Computational Mathematics, 18(5):1131–1198, 2018.
[42] C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, D. Erhan, V. Vanhoucke, and A. Rabinovich. Going deeper with convolutions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1–9, 2015.
[43] R. Vershynin. Introduction to the non-asymptotic analysis of random matrices. In Compressed Sensing, pages 210–268. Cambridge Univ. Press, Cambridge, 2012.
[44] Y. Wu, M. Schuster, Z. Chen, Q. V. Le, M. Norouzi, W. Macherey, M. Krikun, Y. Cao, Q. Gao, K. Macherey, et al. Google's neural machine translation system: Bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144, 2016.
[45] G. Yang. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. CoRR, abs/1902.04760, 2019.
[46] G. Yehudai and O. Shamir. On the power and limitations of random features for understanding neural networks. arXiv preprint arXiv:1904.00687, 2019.
[47] D. Zou, Y. Cao, D. Zhou, and Q. Gu. Stochastic gradient descent optimizes over-parameterized deep relu networks. arXiv preprint arXiv:1811.08888, 2018.

A Initial Estimates

We have derived the dynamic (2.2) of the NTK in Section 3. The kernel $K_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r}, x_\beta)$ is the sum of all the possible terms obtained from $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ by performing any of the replacements in (3.1) with $\alpha = \alpha_1, \alpha_2, \cdots, \alpha_r$. We recall the sets $D_r$ from Section 3, which are constructed recursively. Each vector in $D_r$ contains exactly $r$ diag operations. We have the following proposition on the structure of vectors in $D_r$.

Proposition A.1. Given any expression $v(t) \in D_r$ with some $r \ge 0$, new expressions obtained from $v(t)$ by performing one of the replacements in (3.1) are sums of terms of the following forms:

• $\dfrac{v_1(t)}{\sqrt{m}}$ with $v_1(t) \in D_r \cup D_{r+1}$;

• $\dfrac{v_1(t)}{\sqrt{m}} \dfrac{\langle \sqrt{m}\, x_\alpha^{(k-1)}, \sqrt{m}\, x_\beta^{(k-1)} \rangle}{m}$ with $1 \le k \le H$ and $v_1(t) \in D_{r+1}$;

• $\dfrac{v_1(t)}{\sqrt{m}} \dfrac{\langle \sqrt{m}\, x_\beta^{(\ell)}, v_2(t) \rangle}{m}$ with $1 \le \ell \le H$ and $v_1(t) \in D_{r-s+1}$ and $v_2(t) \in D_s$ for some $s \ge 1$;

• $\dfrac{v_1(t)}{\sqrt{m}} \dfrac{1}{m} \Big\langle \sigma_\ell'(x_\beta) \dfrac{(W_t^{(\ell+1)})^\top}{\sqrt{m}} \cdots \sigma_H'(x_\beta) a_t, v_2(t) \Big\rangle$ with $1 \le \ell \le H$, $v_1(t) \in D_{r-s}$ and $v_2(t) \in D_s$ for some $s \ge 1$.

We remark that the time $t$ in Proposition A.1 is only a parameter and this proposition does not involve dynamics.

Proof of Proposition A.1. By performing the replacement for $a_t$, the new expression is given by $v_1(t)/\sqrt{m}$, with $v_1(t) \in D_r$.

By performing the replacement for $x_\alpha^{(\ell)}$, we get a sum of $\ell$ terms. Each of them is of the form $(v_1(t)/\sqrt{m}) \langle \sqrt{m}\, x_\alpha^{(k-1)}, \sqrt{m}\, x_\beta^{(k-1)} \rangle / m$ with $v_1(t)$ containing one more diag operation. It is easy to check that $v_1(t) \in D_{r+1}$.

By performing the replacement for $W_t^{(\ell)}/\sqrt{m}$, the new expression is given by

$$\frac{v_1(t)}{\sqrt{m}} \frac{\langle \sqrt{m}\, x_\beta^{(\ell)}, v_2(t) \rangle}{m},$$

with $v_1(t) \in D_{r-s+1}$ and $v_2(t) \in D_s$ for some $s \ge 1$.

By performing the replacement for $(W_t^{(\ell)})^\top/\sqrt{m}$, the new expression is given by

$$\frac{v_1(t)}{\sqrt{m}} \frac{1}{m} \Big\langle \sigma_\ell'(x_\beta) \frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}} \cdots \sigma_H'(x_\beta) a_t, v_2(t) \Big\rangle,$$

with $v_1(t) \in D_{r-s}$ and $v_2(t) \in D_s$ for some $s \ge 1$.

Finally, by performing the replacement for $\sigma_\ell^{(u)}(x_\alpha)$, we get a sum of $\ell$ terms of the form $v_1(t)/\sqrt{m}$, with $v_1(t) \in D_{r+1}$.

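As a consequence of Proposition A.1, each summand in $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ is of the form

$$\frac{1}{m^{r/2-1}} \prod_{j=1}^{s} \frac{\langle v_{2j-1}(t), v_{2j}(t) \rangle}{m}, \quad 1 \le s \le r, \quad v_1(t), v_2(t), \cdots, v_{2s}(t) \in D_0 \cup D_1 \cup \cdots \cup D_{r-2}. \tag{A.1}$$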

In the rest of this section we prove claim (2.3) in Theorem 2.3. To evaluate $K_0^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$, we use the tensor program in [45], which was developed to characterize the scaling limit of neural network computations. We show that at time $t = 0$, the vectors $v_j(0)$ in (3.8) are combinations of projections of independent Gaussian vectors. As a consequence, $\langle v_{2j-1}(0), v_{2j}(0) \rangle / m$ concentrates around a certain constant with high probability. So does the product $\prod_{j=1}^{s} \langle v_{2j-1}(0), v_{2j}(0) \rangle / m$. This gives the claim (2.3).

In the next section, we consider the quantity

$$\xi(t) = \max\{\|v_j(t)\|_\infty : v_j(t) \in D_0 \cup D_1 \cup \cdots \cup D_{p-1}\}.$$

Again using the tensor program, we show that with high probability $\|v_j(0)\|_\infty \lesssim (\ln m)^C$. This gives the estimate of $\xi(t)$ at $t = 0$. Next we show that the $(p+1)$-th derivative of $\xi(t)$ can be controlled by $\xi(t)$ itself. This gives a self-consistent differential equation for $\xi(t)$. Combining with the initial estimate of $\xi(t)$, it follows that for time $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds that $\xi(t) \lesssim (\ln m)^C$. In particular, $\|v_j(t)\|_\infty \lesssim (\ln m)^C$. Then the claim (2.4) in Theorem 2.3 follows.

Proposition A.2. Under Assumptions 2.1 and 2.2, there exists a deterministic family of operators $K^{(r)} : \mathcal{X}^r \mapsto \mathbb{R}$ for $2 \le r \le p+1$ (independent of $m$), with $K^{(r)} = 0$ if $r$ is odd, such that with high probability with respect to the random initialization, it holds that

$$\Big\| K_0^{(r)} - \frac{K^{(r)}}{m^{r/2-1}} \Big\|_\infty \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}. \tag{A.2}$$

As we have shown in (A.1), the kernel $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ is a sum of terms of the form

$$\frac{1}{m^{r/2-1}} \prod_{j=1}^{s} \frac{\langle v_{2j-1}(t), v_{2j}(t) \rangle}{m}, \quad 1 \le s \le r, \quad v_1(t), v_2(t), \cdots, v_{2s}(t) \in D_0 \cup D_1 \cup \cdots \cup D_{r-2}. \tag{A.3}$$

To evaluate $K_0^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$, we recall the following conditioning lemma from [45]. With this lemma, we can keep track of the vectors appearing in the expression of $K_0^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$, and their decomposition into combinations of projections of independent Gaussian vectors.

Lemma A.3. Let $W \in \mathbb{R}^{m \times m}$ be a matrix with random Gaussian entries $W_{ij} \sim \mathcal{N}(0, c_w)$. Consider fixed matrices $Q \in \mathbb{R}^{m \times q}$, $Y \in \mathbb{R}^{m \times q}$, $P \in \mathbb{R}^{m \times p}$, $X \in \mathbb{R}^{m \times p}$. Then the distribution of $W$ conditioned on $Y = WQ$ and $X = W^\top P$ is

$$W|_{Y = WQ,\, X = W^\top P} \overset{d}{=} E + \Pi_P^\perp \tilde{W} \Pi_Q^\perp,$$

where $\tilde{W}$ is an independent copy of $W$,

$$E = Y Q^+ + (P^+)^\top X^\top \Pi_Q^\perp = \Pi_P^\perp Y Q^+ + (P^+)^\top X^\top,$$

$Q^+$, $P^+$ are the Moore-Penrose pseudoinverses of $Q$, $P$ respectively, and $\Pi_Q = I_m - \Pi_Q^\perp = Q Q^+$, $\Pi_P = I_m - \Pi_P^\perp = P P^+$ are the orthogonal projections onto the spaces spanned by the columns of $Q$, $P$ respectively.
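The algebraic content of Lemma A.3 can be sanity-checked numerically. The sketch below is not part of the proof and makes several assumptions: unit variance ($c_w = 1$), small illustrative sizes, and a hypothetical helper name `conditioning_check`. It verifies that the two expressions for $E$ agree when $Y = WQ$ and $X = W^\top P$, and that a matrix of the form $E + \Pi_P^\perp \tilde{W} \Pi_Q^\perp$ still satisfies both constraints.

```python
import numpy as np

def conditioning_check(m=50, q=3, p=2, seed=0):
    """Check the deterministic identities behind Lemma A.3 (unit variance assumed)."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(m, m))
    Q = rng.normal(size=(m, q))
    P = rng.normal(size=(m, p))
    Y = W @ Q                      # constraint Y = W Q
    X = W.T @ P                    # constraint X = W^T P

    Q_pinv = np.linalg.pinv(Q)     # Moore-Penrose pseudoinverse, shape (q, m)
    P_pinv = np.linalg.pinv(P)     # shape (p, m)
    I = np.eye(m)
    proj_Q_perp = I - Q @ Q_pinv   # projection onto the complement of span(Q)
    proj_P_perp = I - P @ P_pinv   # projection onto the complement of span(P)

    # The two expressions for E given in the lemma.
    E1 = Y @ Q_pinv + P_pinv.T @ X.T @ proj_Q_perp
    E2 = proj_P_perp @ Y @ Q_pinv + P_pinv.T @ X.T

    # A fresh independent copy of W enters only through the projected part.
    W_tilde = rng.normal(size=(m, m))
    W_cond = E1 + proj_P_perp @ W_tilde @ proj_Q_perp

    return {
        "E_forms_gap": float(np.abs(E1 - E2).max()),
        "Y_gap": float(np.abs(W_cond @ Q - Y).max()),
        "X_gap": float(np.abs(W_cond.T @ P - X).max()),
    }

if __name__ == "__main__":
    print(conditioning_check())
```

All three gaps should be at the level of floating-point round-off. The check covers only the constraints, not the distributional statement of the lemma.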

Proof of Proposition A.2. Without loss of generality, we simply take $x_{\alpha_1} = x_1, x_{\alpha_2} = x_2, \cdots, x_{\alpha_r} = x_r$. We decompose the expression of $K_0^{(r)}(x_1, x_2, \cdots, x_r)$ into sub-expressions. We denote

$$e_0 = a_0, \quad e_{r(\ell-1)+i} = W_0^{(\ell)} x_i^{(\ell-1)}, \quad 1 \le i \le r, \quad 1 \le \ell \le H.$$

In the rest of the proof, we view the $e_i$ as formal expressions, and we denote their values by $\mathrm{val}(e_i)$. For the computation, to evaluate $f(x_1, \theta_0), f(x_2, \theta_0), \cdots, f(x_r, \theta_0)$, we need to sequentially evaluate the expressions $e_1, e_2, \cdots, e_{rH}$. We will express the values of these expressions as combinations of Gaussian vectors in the following way. By repeatedly using Lemma A.3, we have

$$\begin{aligned}
\mathrm{val}(e_1) &= W_0^{(1)} x_1 \overset{d}{=} a_{1,1} g_1, \\
\mathrm{val}(e_2) &= W_0^{(1)} x_2 \overset{d}{=} a_{2,2} g_2 + a_{2,1} g_1, \\
&\cdots\cdots \\
\mathrm{val}(e_r) &= W_0^{(1)} x_r \overset{d}{=} a_{r,r} g_r + a_{r,r-1} g_{r-1} + \cdots + a_{r,1} g_1,
\end{aligned} \tag{A.4}$$

where $g_1, g_2, \cdots, g_r$ are independent standard Gaussian vectors in $\mathbb{R}^m$; the coefficients $a_{i,j}$ can be computed by performing the Gram-Schmidt algorithm over the input vectors $x_1, x_2, \cdots, x_r$. They depend only on the inner products $\langle x_i, x_j \rangle$, and we call them A-variables. In general A-variables are random variables; however, in (A.4) they are deterministic. Thanks to Assumption 2.2, the smallest singular value of the matrix $[x_1, x_2, \cdots, x_r]$ is at least $c_r > 0$, so the leading coefficients satisfy $|a_{1,1}|, |a_{2,2}|, \cdots, |a_{r,r}| \asymp 1$. As a consequence, each of the evaluations of $e_i$ for $1 \le i \le r$ contains a new standard Gaussian vector. For the output of the second layer, again using Lemma A.3, we have

$$\begin{aligned}
\mathrm{val}(e_{r+1}) &= \frac{W_0^{(2)}}{\sqrt{m}} \sigma(\mathrm{val}(e_1)) \overset{d}{=} a_{r+1,r+1} g_{r+1}, \\
\mathrm{val}(e_{r+2}) &= \frac{W_0^{(2)}}{\sqrt{m}} \sigma(\mathrm{val}(e_2)) \overset{d}{=} a_{r+2,r+2} g_{r+2} + a_{r+2,r+1} g_{r+1}, \\
&\cdots\cdots \\
\mathrm{val}(e_{2r}) &= \frac{W_0^{(2)}}{\sqrt{m}} \sigma(\mathrm{val}(e_r)) \overset{d}{=} a_{2r,2r} g_{2r} + a_{2r,2r-1} g_{2r-1} + \cdots + a_{2r,r+1} g_{r+1},
\end{aligned}$$

where $g_{r+1}, g_{r+2}, \cdots, g_{2r}$ are independent standard Gaussian vectors in $\mathbb{R}^m$, which are also independent of $g_1, g_2, \cdots, g_r$; the coefficients $a_{i,j}$ are computed by performing the Gram-Schmidt algorithm over $x_1^{(1)}, x_2^{(1)}, \cdots, x_r^{(1)}$. In this case, the coefficients $a_{i,j}$ are random, since they depend on the inner products $\langle x_i^{(1)}, x_j^{(1)} \rangle$. However, the inner products

$$\langle x_i^{(1)}, x_j^{(1)} \rangle = \frac{1}{m} \big\langle \sigma(a_{i,i} g_i + a_{i,i-1} g_{i-1} + \cdots + a_{i,1} g_1),\ \sigma(a_{j,j} g_j + a_{j,j-1} g_{j-1} + \cdots + a_{j,1} g_1) \big\rangle$$

are averages of $m$ independent identically distributed quantities, each of which is a function of Gaussian variables. Therefore, $\langle x_i^{(1)}, x_j^{(1)} \rangle$ has a scaling limit as the width of the neural network $m \to \infty$, and strongly concentrates around this limit. In other words, with high probability we have

$$\lim_{m \to \infty} a_{i,j} = \tilde{a}_{i,j}, \quad a_{i,j} = \tilde{a}_{i,j} + O\Big(\frac{(\ln m)^C}{\sqrt{m}}\Big). \tag{A.5}$$

We will see soon that, by the same reasoning, all the A-variables appearing in this section satisfy the relation (A.5). Moreover, in the limit $m \to \infty$, the Gram matrix of $x_1^{(1)}, x_2^{(1)}, \cdots, x_r^{(1)}$ is of full rank. Otherwise there exist constants $\lambda_1, \lambda_2, \cdots, \lambda_r$, not all zero, such that

$$\lambda_1 \sigma(\tilde{a}_{1,1} G_1) + \lambda_2 \sigma(\tilde{a}_{2,2} G_2 + \tilde{a}_{2,1} G_1) + \cdots + \lambda_r \sigma(\tilde{a}_{r,r} G_r + \tilde{a}_{r,r-1} G_{r-1} + \cdots + \tilde{a}_{r,1} G_1) = 0, \tag{A.6}$$

for independent Gaussian variables $G_1, G_2, \cdots, G_r \sim \mathcal{N}(0, 1)$. This is impossible, unless the expression (A.6) is literally zero, i.e. $\lambda_1 = \lambda_2 = \cdots = \lambda_r = 0$. Therefore, in the limit $m \to \infty$, the Gram matrix of $x_1^{(1)}, x_2^{(1)}, \cdots, x_r^{(1)}$ is of full rank. We conclude that $|\tilde{a}_{r+1,r+1}|, |\tilde{a}_{r+2,r+2}|, \cdots, |\tilde{a}_{2r,2r}| \asymp 1$. Combining with (A.5), with high probability it holds that $|a_{r+1,r+1}|, |a_{r+2,r+2}|, \cdots, |a_{2r,2r}| \asymp 1$. Again, each of the evaluations of $e_i$ for $r+1 \le i \le 2r$ contains a new standard Gaussian vector. By repeating the above argument, we get that for any $1 \le j \le rH$,

$$\mathrm{val}(e_j) \overset{d}{=} a_{j,j} g_j + a_{j,j-1} g_{j-1} + \cdots + a_{j,1} g_1,$$

where $g_1, g_2, \cdots, g_{rH}$ are independent standard Gaussian vectors, the A-variables $a_{j,j}, a_{j,j-1}, \cdots, a_{j,1}$ concentrate around their limits, i.e. with high probability (A.5) holds, and $|a_{j,j}| \asymp 1$.

Formally as expressions, we have for $2 \le \ell \le H$,

$$\big[e_{r(\ell-1)+1}, e_{r(\ell-1)+2}, \cdots, e_{r\ell}\big] = W_0^{(\ell)} \Big[\frac{\sigma(e_{r(\ell-2)+1})}{\sqrt{m}}, \frac{\sigma(e_{r(\ell-2)+2})}{\sqrt{m}}, \cdots, \frac{\sigma(e_{r(\ell-1)})}{\sqrt{m}}\Big].$$

To use Lemma A.3 in the future, we denote for $2 \le \ell \le H$,

$$Y_0^{(\ell)} = \mathrm{val}\big[e_{r(\ell-1)+1}, e_{r(\ell-1)+2}, \cdots, e_{r\ell}\big], \quad Q_0^{(\ell)} = \mathrm{val}\Big[\frac{\sigma(e_{r(\ell-2)+1})}{\sqrt{m}}, \frac{\sigma(e_{r(\ell-2)+2})}{\sqrt{m}}, \cdots, \frac{\sigma(e_{r(\ell-1)})}{\sqrt{m}}\Big].$$

Then $Y_0^{(\ell)} = W_0^{(\ell)} Q_0^{(\ell)}$ for $2 \le \ell \le H$.

To estimate $K_0^{(r)}(x_1, x_2, \cdots, x_r)$, we need to decompose the expression of $K_0^{(r)}(x_1, x_2, \cdots, x_r)$ into subexpressions $e_{rH+1}, e_{rH+2}, e_{rH+3}, \cdots$ in the following way. Each summand in $K_0^{(r)}(x_1, x_2, \cdots, x_r)$ is of the form (3.8). For each of the vectors $v_j(0), u_j(0)$ appearing there, we evaluate it from right to left. Each time we need to multiply by one of the matrices $W_0^{(2)}, (W_0^{(2)})^\top, W_0^{(3)}, (W_0^{(3)})^\top, \cdots, W_0^{(H)}, (W_0^{(H)})^\top$, we add a new subexpression corresponding to the whole expression, if it has not appeared before. For example, we have the following expression in $K_0^{(2)}(\cdot, \cdot)$:

$$\sigma_\ell'(x) \frac{(W_0^{(\ell+1)})^\top}{\sqrt{m}} \cdots \sigma_H'(x) a_0. \tag{A.7}$$

We decompose it into subexpressions in the following way:

$$\begin{aligned}
e_{rH+1} &= \frac{(W_0^{(H)})^\top}{\sqrt{m}} \sigma_H'(x) a_0, \\
e_{rH+2} &= \frac{(W_0^{(H-1)})^\top}{\sqrt{m}} \sigma_{H-1}'(x) \frac{(W_0^{(H)})^\top}{\sqrt{m}} \sigma_H'(x) a_0, \\
&\cdots\cdots \\
e_{(r+1)H-\ell} &= \frac{(W_0^{(\ell+1)})^\top}{\sqrt{m}} \cdots \sigma_H'(x) a_0 = \text{(A.7)}.
\end{aligned}$$

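Then for each $e_j$ with $j \ge rH + 1$, either $e_j = (W_0^{(\ell)}/\sqrt{m}) f_j$ or $e_j = ((W_0^{(\ell)})^\top/\sqrt{m}) f_j$ for some $2 \le \ell \le H$, and $f_j$ is an expression of the following form:

$$\mathrm{Mul}(e_0, e_1, \cdots, e_{j-1}) = \big\{\text{entrywise products of } e_0,\ \{\sigma^{(s)}(e_i)\}_{0 \le s \le r-1,\, 1 \le i \le rH},\ \{e_i\}_{rH+1 \le i \le j-1}\big\}.$$

For $2 \le \ell \le H$, we denote the sets

$$\begin{aligned}
S_\tau^{(\ell)} &= \{1 \le j \le rH + \tau : e_j \text{ ends with multiplying } W_0^{(\ell)}/\sqrt{m}\}, \\
T_\tau^{(\ell)} &= \{1 \le j \le rH + \tau : e_j \text{ ends with multiplying } (W_0^{(\ell)})^\top/\sqrt{m}\}.
\end{aligned}$$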

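Formally as expressions, we have for $2 \le \ell \le H$,

$$[e_j]_{j \in S_\tau^{(\ell)}} = W_0^{(\ell)} \Big[\frac{f_j}{\sqrt{m}}\Big]_{j \in S_\tau^{(\ell)}}, \quad [e_j]_{j \in T_\tau^{(\ell)}} = (W_0^{(\ell)})^\top \Big[\frac{f_j}{\sqrt{m}}\Big]_{j \in T_\tau^{(\ell)}}.$$

To use Lemma A.3 in the future, we denote for $2 \le \ell \le H$,

$$\begin{aligned}
Y_\tau^{(\ell)} &= \mathrm{val}[e_j]_{j \in S_\tau^{(\ell)}}, & Q_\tau^{(\ell)} &= \mathrm{val}\Big[\frac{f_j}{\sqrt{m}}\Big]_{j \in S_\tau^{(\ell)}}, \\
X_\tau^{(\ell)} &= \mathrm{val}[e_j]_{j \in T_\tau^{(\ell)}}, & P_\tau^{(\ell)} &= \mathrm{val}\Big[\frac{f_j}{\sqrt{m}}\Big]_{j \in T_\tau^{(\ell)}}.
\end{aligned} \tag{A.8}$$

Then for $2 \le \ell \le H$,

$$Y_\tau^{(\ell)} = W_0^{(\ell)} Q_\tau^{(\ell)}, \quad X_\tau^{(\ell)} = (W_0^{(\ell)})^\top P_\tau^{(\ell)}.$$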

In the following we prove by induction that

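Claim A.4. For $\tau \ge 1$, the following holds.

(i) The limits as $m \to \infty$ of the Gram matrices of the columns of $Q_\tau^{(\ell)}$, and of the columns of $P_\tau^{(\ell)}$, as defined in (A.8), are non-degenerate;

(ii) Let $\mathrm{Mul}(a_0, g_1, g_2, \cdots, g_{rH+\tau-1})$ be the set of entrywise products of $a_0$, $\{\sigma^{(s)}(a_{i,i} g_i + a_{i,i-1} g_{i-1} + \cdots + a_{i,1} g_1)\}_{0 \le s \le r-1,\, 1 \le i \le rH}$ and $\{g_i\}_{rH+1 \le i \le rH+\tau-1}$, and let $\mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{rH+\tau-1})$ be the set of linear combinations of $\mathrm{Mul}(a_0, g_1, g_2, \cdots, g_{rH+\tau-1})$ with A-variables as coefficients. The evaluation of $e_{rH+\tau}$ has the following form

$$\mathrm{val}(e_{rH+\tau}) = a_{rH+\tau, rH+\tau}\, g_{rH+\tau} + \mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{rH+\tau-1}),$$

where $\tilde{g}_{rH+\tau} \sim \mathcal{N}(0, I_m)$ is a standard Gaussian vector,

$$g_{rH+\tau} = \Pi^\perp_{Q_\tau^{(\ell)}} \tilde{g}_{rH+\tau},$$

if the expression $e_{rH+\tau}$ ends with multiplying $W_0^{(\ell)}/\sqrt{m}$, and

$$g_{rH+\tau} = \Pi^\perp_{P_\tau^{(\ell)}} \tilde{g}_{rH+\tau},$$

if the expression $e_{rH+\tau}$ ends with multiplying $(W_0^{(\ell)})^\top/\sqrt{m}$. Moreover, with high probability we have

$$\lim_{m \to \infty} a_{rH+\tau, rH+\tau} = \tilde{a}_{rH+\tau, rH+\tau} \neq 0, \quad a_{rH+\tau, rH+\tau} = \tilde{a}_{rH+\tau, rH+\tau} + O\Big(\frac{(\ln m)^C}{\sqrt{m}}\Big).$$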

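Proof of Claim A.4. We assume that the statements of Claim A.4 hold up to $\tau$ and prove them for $\tau + 1$. Without loss of generality, we assume that $e_{rH+\tau+1}$ ends with multiplying $W_0^{(\ell)}/\sqrt{m}$. Then $e_{rH+\tau+1} = (W_0^{(\ell)}/\sqrt{m}) f_{rH+\tau+1}$, and $f_{rH+\tau+1} \in \mathrm{Mul}(e_0, e_1, \cdots, e_{rH+\tau})$. Moreover, $S_{\tau+1}^{(\ell)} = S_\tau^{(\ell)} \cup \{rH + \tau + 1\}$, $T_{\tau+1}^{(\ell)} = T_\tau^{(\ell)}$, and

$$\begin{aligned}
Y_{\tau+1}^{(\ell)} &= \mathrm{val}[e_j]_{j \in S_{\tau+1}^{(\ell)}} = \big[Y_\tau^{(\ell)}, \mathrm{val}(e_{rH+\tau+1})\big], & Q_{\tau+1}^{(\ell)} &= \mathrm{val}\Big[\frac{f_j}{\sqrt{m}}\Big]_{j \in S_{\tau+1}^{(\ell)}} = \Big[Q_\tau^{(\ell)}, \frac{\mathrm{val}(f_{rH+\tau+1})}{\sqrt{m}}\Big], \\
X_{\tau+1}^{(\ell)} &= \mathrm{val}[e_j]_{j \in T_{\tau+1}^{(\ell)}} = X_\tau^{(\ell)}, & P_{\tau+1}^{(\ell)} &= \mathrm{val}\Big[\frac{f_j}{\sqrt{m}}\Big]_{j \in T_{\tau+1}^{(\ell)}} = P_\tau^{(\ell)}.
\end{aligned} \tag{A.9}$$

By our induction assumption, the limit as $m \to \infty$ of the Gram matrix of the columns of $P_{\tau+1}^{(\ell)}$ is non-degenerate. To prove (i) in Claim A.4, we only need to show that the limit as $m \to \infty$ of the Gram matrix of the columns of $Q_{\tau+1}^{(\ell)}$ is non-degenerate. We prove it by contradiction. We recall from (A.9) that $Q_{\tau+1}^{(\ell)} = \mathrm{val}[f_j/\sqrt{m}]_{j \in S_{\tau+1}^{(\ell)}}$. If the limit of the Gram matrix of the columns of $Q_{\tau+1}^{(\ell)}$ is degenerate, informally, there exist constants $\lambda_j$ such that

$$\lim_{m \to \infty} \sum_{j \in S_\tau^{(\ell)}} \lambda_j \mathrm{val}(f_j) + \lambda_{rH+\tau+1} \mathrm{val}(f_{rH+\tau+1}) = 0. \tag{A.10}$$

We recall that $f_j$ is an expression of entrywise products of $e_0$, $\{\sigma^{(s)}(e_i)\}_{0 \le s \le r-1,\, 1 \le i \le rH}$ and $\{e_i\}_{rH+1 \le i \le j-1}$, and by our induction hypothesis $\mathrm{val}(e_i) = a_{i,i} g_i + \mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{i-1})$, with $|a_{i,i}| \asymp 1$ with high probability. Moreover, as $m \to \infty$, the vectors $a_0, g_1, g_2, \cdots, g_{rH+\tau}$ converge to independent standard Gaussian vectors. Thus (A.10) implies that, as formal expressions,

$$\sum_{j \in S_\tau^{(\ell)}} \lambda_j f_j + \lambda_{rH+\tau+1} f_{rH+\tau+1} = 0.$$

However, this indicates that $f_{rH+\tau+1} = \lambda f_j$ for some $j \le rH + \tau$, which contradicts our construction that $e_{rH+\tau+1}$ has not appeared before. This finishes the proof of (i) in Claim A.4.

For the proof of (ii) in Claim A.4, thanks to Lemma A.3, we have

$$\mathrm{val}(e_{rH+\tau+1}) \overset{d}{=} \Big( Y_\tau^{(\ell)} (Q_\tau^{(\ell)})^+ + ((P_\tau^{(\ell)})^+)^\top (X_\tau^{(\ell)})^\top \Pi^\perp_{Q_\tau^{(\ell)}} + \Pi^\perp_{P_\tau^{(\ell)}} \tilde{W} \Pi^\perp_{Q_\tau^{(\ell)}} \Big) \frac{\mathrm{val}(f_{rH+\tau+1})}{\sqrt{m}}. \tag{A.11}$$

Since $f_{rH+\tau+1} \in \mathrm{Mul}(e_0, e_1, \cdots, e_{rH+\tau})$ is an expression of entrywise products of $e_0$, $\{\sigma^{(s)}(e_i)\}_{0 \le s \le r-1,\, 1 \le i \le rH}$ and $\{e_i\}_{rH+1 \le i \le rH+\tau}$, and by our induction assumption for $1 \le i \le rH + \tau$,

$$\mathrm{val}(e_i) = a_{i,i} g_i + \mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{i-1}),$$

we conclude that

$$\mathrm{val}(f_{rH+\tau+1}) \in \mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{rH+\tau}).$$

By our induction assumption, the columns of $Q_\tau^{(\ell)}$ and $P_\tau^{(\ell)}$ as $m \to \infty$ are of full rank. The first two terms in (A.11) are linear combinations of the columns of $Y_\tau^{(\ell)}$ and the columns of $P_\tau^{(\ell)}$ with A-variables as coefficients:

$$\Big( Y_\tau^{(\ell)} (Q_\tau^{(\ell)})^+ + ((P_\tau^{(\ell)})^+)^\top (X_\tau^{(\ell)})^\top \Pi^\perp_{Q_\tau^{(\ell)}} \Big) \frac{\mathrm{val}(f_{rH+\tau+1})}{\sqrt{m}} \in \mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{rH+\tau}). \tag{A.12}$$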

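For the last term in (A.11), we can rewrite it as

$$\Pi^\perp_{P_\tau^{(\ell)}} \tilde{W} \Pi^\perp_{Q_\tau^{(\ell)}} \frac{\mathrm{val}(f_{rH+\tau+1})}{\sqrt{m}} = a_{rH+\tau+1, rH+\tau+1}\, \Pi^\perp_{P_\tau^{(\ell)}} \tilde{g}_{rH+\tau+1}, \tag{A.13}$$

where $\tilde{g}_{rH+\tau+1}$ is an independent Gaussian vector and

$$a_{rH+\tau+1, rH+\tau+1} = \frac{1}{\sqrt{m}} \big\| \Pi^\perp_{Q_\tau^{(\ell)}} \mathrm{val}(f_{rH+\tau+1}) \big\|_2.$$

By the same argument as before, the A-variable $a_{rH+\tau+1, rH+\tau+1}$ strongly concentrates around its limit. With high probability we have

$$\lim_{m \to \infty} a_{rH+\tau+1, rH+\tau+1} = \tilde{a}_{rH+\tau+1, rH+\tau+1}, \quad a_{rH+\tau+1, rH+\tau+1} = \tilde{a}_{rH+\tau+1, rH+\tau+1} + O\Big(\frac{(\ln m)^C}{\sqrt{m}}\Big).$$

Moreover, as we have just proven, (i) in Claim A.4 implies that as $m \to \infty$, the limit of the Gram matrix of $\{\mathrm{val}(f_j)/\sqrt{m}\}_{j \in S_\tau^{(\ell)}} \cup \{\mathrm{val}(f_{rH+\tau+1})/\sqrt{m}\}$ is non-degenerate. We conclude that $\tilde{a}_{rH+\tau+1, rH+\tau+1} \neq 0$, and hence with high probability $|a_{rH+\tau+1, rH+\tau+1}| \asymp 1$. This finishes the proof of Claim A.4.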
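The concentration of A-variables at rate $O((\ln m)^C/\sqrt{m})$ rests on the fact that inner products such as $\langle x_i^{(1)}, x_j^{(1)} \rangle$ are averages of $m$ i.i.d. terms. The following sketch illustrates this numerically under assumptions not fixed by the text: $\sigma = \tanh$, unit-variance Gaussian weights, two hand-picked unit inputs, and a very wide network used as a stand-in for the $m \to \infty$ limit. The function names are hypothetical.

```python
import numpy as np

def layer_inner_product(m, x_i, x_j, rng):
    """(1/m) <sigma(W x_i), sigma(W x_j)> for a fresh m x d Gaussian W, with sigma = tanh."""
    W = rng.normal(size=(m, x_i.shape[0]))
    return float(np.tanh(W @ x_i) @ np.tanh(W @ x_j)) / m

def concentration_table(widths=(100, 1_000, 10_000, 100_000), seed=0):
    """Absolute deviation of the width-m inner product from a wide-network proxy of its limit."""
    rng = np.random.default_rng(seed)
    x_i = np.array([1.0, 0.0, 0.0])
    x_j = np.array([0.6, 0.8, 0.0])
    # Proxy for the scaling limit: one very wide network (an assumption, not an exact value).
    reference = layer_inner_product(1_000_000, x_i, x_j, rng)
    return {m: abs(layer_inner_product(m, x_i, x_j, rng) - reference) for m in widths}

if __name__ == "__main__":
    for width, deviation in concentration_table().items():
        print(width, deviation)
```

The deviations are random, so individual rows need not decrease monotonically, but their typical size should scale like $1/\sqrt{m}$.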

From the discussion above, the evaluation of any subexpression $e_i$ in $K_0^{(r)}(x_1, x_2, \cdots, x_r)$ is of the form

$$\mathrm{val}(e_i) = a_{i,i} g_i + \mathrm{LinMul}(a_0, g_1, g_2, \cdots, g_{i-1}). \tag{A.14}$$

In particular, the vectors $v_j(t)$ at time $t = 0$ in (A.1) are also of the form (A.14). Their inner products concentrate around their limits as $m \to \infty$,

$$\frac{\langle v_{2j-1}(0), v_{2j}(0) \rangle}{m} = \lim_{m \to \infty} \frac{\langle v_{2j-1}(0), v_{2j}(0) \rangle}{m} + O\Big(\frac{(\ln m)^C}{\sqrt{m}}\Big). \tag{A.15}$$

Hence there exists a deterministic operator $K^{(r)} : \mathcal{X}^r \mapsto \mathbb{R}$ for $2 \le r \le p + 1$, such that with high probability

$$\Big| m^{r/2-1} K_0^{(r)}(x_1, x_2, \cdots, x_r) - K^{(r)}(x_1, x_2, \cdots, x_r) \Big| \lesssim \frac{(\ln m)^C}{\sqrt{m}}.$$

By a union bound over all $r$-tuples of data points $(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$, we conclude that with high probability

$$\Big\| K_0^{(r)} - \frac{K^{(r)}}{m^{r/2-1}} \Big\|_\infty \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}.$$

If $2 \nmid r$, then the degree of $a_0$ in $K_0^{(r)}$ is odd, so $\mathbb{E}[K_0^{(r)}] = 0$. It is necessary that $K^{(r)} = 0$. This finishes the proof of Proposition A.2.
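For the first layer, the A-variables in (A.4) can be made concrete. A minimal sketch, assuming unit-variance Gaussian entries for $W_0^{(1)}$ and inputs stored as rows of a matrix: the Gram-Schmidt coefficients $a_{i,j}$ then coincide with the lower-triangular Cholesky factor of the Gram matrix $(\langle x_i, x_j \rangle)_{i,j}$, which makes the covariance of $\mathrm{val}(e_1), \cdots, \mathrm{val}(e_r)$ match coordinate by coordinate. The helper names are hypothetical.

```python
import numpy as np

def a_variables_first_layer(X):
    """Lower-triangular A with X X^T = A A^T; row i holds a_{i,1}, ..., a_{i,i}.

    Assumes the inputs x_1, ..., x_r are the rows of X and are linearly independent.
    """
    gram = X @ X.T                      # inner products <x_i, x_j>
    return np.linalg.cholesky(gram)     # positive diagonal, as in Gram-Schmidt

def sample_first_layer(X, m, rng):
    """Draw val(e_i) = sum_j a_{i,j} g_j from independent standard Gaussians g_j in R^m."""
    A = a_variables_first_layer(X)
    g = rng.normal(size=(X.shape[0], m))   # rows g_1, ..., g_r
    return A @ g                           # row i plays the role of W x_i

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = np.array([[1.0, 0.0, 0.0],
                  [0.6, 0.8, 0.0],
                  [0.0, 0.6, 0.8]])
    print(a_variables_first_layer(X))
    print(sample_first_layer(X, m=4, rng=rng).shape)
```

A positive diagonal of `A` corresponds to the statement $|a_{i,i}| \asymp 1$, which requires the smallest singular value of $[x_1, \cdots, x_r]$ to be bounded below.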

Corollary A.5. Under Assumptions 2.1 and 2.2, for any expression $v(t) \in D_r$ with $0 \le r \le 2p$, the following holds with high probability:

$$\|v(0)\|_\infty \lesssim (\ln m)^C.$$

Proof. By the same argument as in the proof of Proposition A.2, we can evaluate $v(0)$ as a combination of standard Gaussian vectors,

$$v(0) \in \mathrm{LinMul}(a_0, g_1, g_2, \cdots),$$

where the set LinMul is as defined in Claim A.4: $\mathrm{LinMul}(a_0, g_1, g_2, \cdots)$ is the set of linear combinations of $\mathrm{Mul}(a_0, g_1, g_2, \cdots)$ with A-variables as coefficients; $\mathrm{Mul}(a_0, g_1, g_2, \cdots)$ is the set of entrywise products of $a_0$, $\{\sigma^{(s)}(a_{i,i} g_i + a_{i,i-1} g_{i-1} + \cdots + a_{i,1} g_1)\}_{0 \le s \le r+1,\, 1 \le i \le rH}$ and $\{g_i\}_{i \ge rH+1}$. Since the vectors $g_i$ are projections of independent Gaussian vectors, with high probability $\|g_i\|_\infty \lesssim (\ln m)^C$. The same bound holds for any vector in $\mathrm{LinMul}(a_0, g_1, g_2, \cdots)$.
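The logarithmic size of $\|g_i\|_\infty$ comes from the standard union bound $\mathbb{P}(\max_k |g_k| > t) \le 2m e^{-t^2/2}$ for a standard Gaussian vector in $\mathbb{R}^m$. A small sketch, with a hypothetical helper name and an unprojected Gaussian vector as a simplifying assumption, compares the observed sup norm with $\sqrt{2 \ln(2m)}$:

```python
import math
import numpy as np

def gaussian_sup_norm(m, seed=0):
    """Sup norm of one standard Gaussian vector in R^m."""
    rng = np.random.default_rng(seed)
    return float(np.abs(rng.normal(size=m)).max())

if __name__ == "__main__":
    for m in (10**3, 10**4, 10**5):
        # Second column is the scale suggested by the union bound.
        print(m, gaussian_sup_norm(m), math.sqrt(2.0 * math.log(2.0 * m)))
```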

B A Priori Estimates

In this section, we prove the claim (2.4) in Theorem 2.3.

Proposition B.1 (A priori $L_2$ bounds). Under Assumptions 2.1 and 2.2, for any time $t \ge 0$, we have

$$\sum_{\beta=1}^{n} |f_\beta(t) - y_\beta|^2 \le \sum_{\beta=1}^{n} |f_\beta(0) - y_\beta|^2 = O(n), \tag{B.1}$$

and with high probability with respect to the random initialization, for $t \lesssim \sqrt{m}$,

$$\frac{1}{\sqrt{m}} \max\big\{ \|W_t^{(1)}\|_{2 \to 2}, \|W_t^{(2)}\|_{2 \to 2}, \|(W_t^{(2)})^\top\|_{2 \to 2}, \cdots, \|W_t^{(H)}\|_{2 \to 2}, \|(W_t^{(H)})^\top\|_{2 \to 2}, \|a_t\|_2 \big\} \lesssim 1. \tag{B.2}$$

Proof of Proposition B.1. From the defining relation (1.7) of $K_t^{(2)}(\cdot, \cdot)$, it is non-negative definite. Using (1.6), we get

$$\partial_t \sum_{\beta=1}^{n} |f_\beta(t) - y_\beta|^2 = -\sum_{\alpha, \beta=1}^{n} K_t^{(2)}(x_\alpha, x_\beta)(f_\alpha(t) - y_\alpha)(f_\beta(t) - y_\beta) \le 0,$$

and (B.1) follows:

$$\sum_{\beta=1}^{n} |f_\beta(t) - y_\beta|^2 \le \sum_{\beta=1}^{n} |f_\beta(0) - y_\beta|^2 = O(n).$$

To prove (B.2), we define

$$\xi(t) = \frac{1}{\sqrt{m}} \max\big\{ \|W_t^{(1)}\|_{2 \to 2}, \|W_t^{(2)}\|_{2 \to 2}, \|(W_t^{(2)})^\top\|_{2 \to 2}, \cdots, \|W_t^{(H)}\|_{2 \to 2}, \|(W_t^{(H)})^\top\|_{2 \to 2}, \|a_t\|_2 \big\}.$$

We notice that for $t = 0$, $W_0^{(1)}$ is an $m \times d$ random Gaussian matrix, $W_0^{(2)}, W_0^{(3)}, \cdots, W_0^{(H)}$ are $m \times m$ random Gaussian matrices, and $a_0$ is a Gaussian vector of length $m$. From random matrix theory [43], we have that, with high probability,

$$\xi(0) \lesssim 1. \tag{B.3}$$
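The bound (B.3) reflects the fact that an $m \times m$ matrix with i.i.d. Gaussian entries has operator norm of order $\sqrt{m}$. A minimal numerical sketch, assuming unit-variance entries and using hypothetical helper names:

```python
import numpy as np

def scaled_operator_norm(m, seed=0):
    """||W||_{2->2} / sqrt(m) for an m x m matrix with i.i.d. N(0, 1) entries."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(m, m))
    return float(np.linalg.norm(W, 2)) / np.sqrt(m)   # ord=2 gives the largest singular value

def scaled_vector_norm(m, seed=0):
    """||a||_2 / sqrt(m) for a standard Gaussian vector a of length m."""
    rng = np.random.default_rng(seed)
    return float(np.linalg.norm(rng.normal(size=m))) / np.sqrt(m)

if __name__ == "__main__":
    for m in (100, 400, 1600):
        print(m, scaled_operator_norm(m), scaled_vector_norm(m))
```

For unit variance the first quantity stays close to 2 and the second close to 1 as $m$ grows, consistent with $\xi(0) \lesssim 1$.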

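In the following we derive an upper bound on $\partial_t \xi(t)$, which combined with (B.3) gives us the desired bound (B.2). For any $\ell \ge 1$, we have

$$\begin{aligned}
\|x^{(\ell)}\|_2 = \frac{1}{\sqrt{m}} \|\sigma(W_t^{(\ell)} x^{(\ell-1)})\|_2 &\le \frac{1}{\sqrt{m}} \sqrt{\sum_{i=1}^{m} \big( |\sigma(0)| + c_1 |(W_t^{(\ell)} x^{(\ell-1)})_i| \big)^2} \\
&\le |\sigma(0)| + \frac{c_1}{\sqrt{m}} \|W_t^{(\ell)} x^{(\ell-1)}\|_2 \le |\sigma(0)| + c_1 \xi(t) \|x^{(\ell-1)}\|_2,
\end{aligned}$$

where we used Assumption 2.1 that $\sigma$ is $c_1$-Lipschitz, and Assumption 2.2 that $\|x\|_2 \lesssim 1$. Inductively, we have the following estimate:

$$\|x^{(\ell)}\|_2 \le c_1^\ell \xi(t)^\ell \|x\|_2 + |\sigma(0)| \big( 1 + c_1 \xi(t) + \cdots + c_1^{\ell-1} \xi(t)^{\ell-1} \big) \lesssim c_1^\ell \xi(t)^\ell. \tag{B.4}$$

Using (1.4), (1.5) and (B.4), we have the following bounds:

$$\begin{aligned}
\partial_t \|W_t^{(\ell)}\|_{2 \to 2} &\le \frac{1}{n} \sum_{\beta=1}^{n} \Big\| \sigma_\ell'(x_\beta) \frac{(W_t^{(\ell+1)})^\top}{\sqrt{m}} \cdots \sigma_H'(x_\beta) \frac{a_t}{\sqrt{m}} \Big\|_2 \|x_\beta^{(\ell-1)}\|_2\, |f_\beta(t) - y_\beta| \\
&\lesssim \frac{1}{n} c_1^H \xi(t)^H \sum_{\beta=1}^{n} |f_\beta(t) - y_\beta| \lesssim c_1^H \xi(t)^H \sqrt{\frac{1}{n} \sum_{\beta=1}^{n} |f_\beta(t) - y_\beta|^2} \lesssim c_1^H \xi(t)^H,
\end{aligned} \tag{B.5}$$

and

$$\begin{aligned}
\partial_t \|a_t\|_2 &\le \frac{1}{n} \sum_{\beta=1}^{n} \|x_\beta^{(H)}\|_2\, |f_\beta(t) - y_\beta| \lesssim \frac{1}{n} c_1^H \xi(t)^H \sum_{\beta=1}^{n} (1 + \|x_\beta\|_2)\, |f_\beta(t) - y_\beta| \\
&\lesssim c_1^H \xi(t)^H \sqrt{\frac{1}{n} \sum_{\beta=1}^{n} |f_\beta(t) - y_\beta|^2} \lesssim c_1^H \xi(t)^H,
\end{aligned} \tag{B.6}$$

where we used the Cauchy-Schwarz inequality and (B.1). Similarly,

$$\partial_t \|(W_t^{(\ell)})^\top\|_{2 \to 2} \lesssim c_1^H \xi(t)^H. \tag{B.7}$$

The estimates (B.5), (B.6) and (B.7) together imply an upper bound for $\partial_t \xi(t)$: there exists some large constant $C > 0$ such that

$$\partial_t \xi(t) \le \frac{C c_1^H}{\sqrt{m}} \xi(t)^H.$$

If $H = 1$, we have

$$\xi(t) \le e^{C c_1^H t/\sqrt{m}}\, \xi(0),$$

and for $H \ge 2$, we get

$$\xi(t) \le \Big( \xi(0)^{-(H-1)} - (H-1)\, C c_1^H t/\sqrt{m} \Big)^{-1/(H-1)}.$$

In both cases, we have that $\xi(t) \lesssim 1$ provided that $t \lesssim \sqrt{m}$, where the implicit constants depend on the depth $H$. This finishes the proof of Proposition B.1.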
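The last step is a comparison argument for the scalar ODE $\xi' = \rho\, \xi^H$ with $\rho = C c_1^H/\sqrt{m}$. The sketch below is only an illustration with made-up constants: the values of `rate`, `xi0` and `H` are assumptions, not values from the paper, and the function names are hypothetical. It compares the closed-form solution with a forward Euler integration and shows that the solution stays of order one as long as $\rho\, t$ is small, i.e. $t \lesssim \sqrt{m}$.

```python
import math

def xi_closed_form(t, xi0, H, rate):
    """Solution of xi' = rate * xi**H with xi(0) = xi0; returns inf after blow-up."""
    if H == 1:
        return xi0 * math.exp(rate * t)
    remaining = xi0 ** (-(H - 1)) - (H - 1) * rate * t
    if remaining <= 0.0:
        return math.inf
    return remaining ** (-1.0 / (H - 1))

def xi_euler(t, xi0, H, rate, steps=100_000):
    """Forward Euler integration of the same ODE, used as an independent check."""
    dt = t / steps
    xi = xi0
    for _ in range(steps):
        xi += dt * rate * xi ** H
    return xi

if __name__ == "__main__":
    m = 10**6
    rate = 1.0 / math.sqrt(m)        # stands in for C * c1**H / sqrt(m), with C = c1 = 1
    for frac in (0.01, 0.05, 0.1):
        t = frac * math.sqrt(m)
        print(frac, xi_closed_form(t, 2.0, 3, rate), xi_euler(t, 2.0, 3, rate))
```

For $H \ge 2$ the closed form blows up at $t = \xi(0)^{-(H-1)} \sqrt{m} / ((H-1) C c_1^H)$, which is why the bound is only claimed for $t \lesssim \sqrt{m}$ with a depth-dependent constant.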

As we have shown in Section 3, (3.8), the kernel $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \cdots, x_{\alpha_r})$ is a sum of terms of the form

$$\frac{1}{m^{r/2-1}} \prod_{j=1}^{s} \frac{\langle v_{2j-1}(t), v_{2j}(t) \rangle}{m}, \quad 1 \le s \le r, \quad v_j(t) \in D_0 \cup D_1 \cup \cdots \cup D_{r-2}. \tag{B.8}$$

In the following we derive an upper bound on $\|\mathrm{diag}(f_t)\|_{2 \to 2} = \|f_t\|_\infty$, where $\mathrm{diag}(\cdot)$ is the diagonalization of a vector, for any $f_t \in D_0 \cup D_1 \cup \cdots \cup D_r$.

Proposition B.2. We assume Assumptions 2.1 and 2.2. Fix time $t \ge 0$ and $r \ge 0$. Suppose that for all expressions $f_t \in D_0 \cup D_1 \cup \cdots \cup D_r$, the following holds:

$$\|\mathrm{diag}(f_t)\|_{2 \to 2} \le M, \tag{B.9}$$

for some constant $M \ge 1$. Then for any $f_t \in D_s$ with $0 \le s \le 2r + 2$, the following holds:

$$\|\mathrm{diag}(f_t)\|_{2 \to 2} \le \|f_t\|_2 \lesssim M^s \sqrt{m}. \tag{B.10}$$

Proof of Proposition B.2. We notice that $\|\mathrm{diag}(f_t)\|_{2 \to 2}$ equals the $L_\infty$ norm of the vector $f_t$. The $L_\infty$ norm of $f_t$ is bounded by its $L_2$ norm,

$$\|\mathrm{diag}(f_t)\|_{2 \to 2} = \|f_t\|_\infty \le \|f_t\|_2,$$

which gives the first inequality of (B.10). Notice that the last inequality above is not optimal and it generically costs a factor $\sqrt{m}$.
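The norm relations used here are elementary and can be checked directly. The sketch below uses a hypothetical helper name and a random test vector, which is an assumption for illustration. It computes the operator norm of $\mathrm{diag}(f)$ together with the sup norm and Euclidean norm of $f$.

```python
import numpy as np

def diag_norms(f):
    """Return (||diag(f)||_{2->2}, ||f||_inf, ||f||_2) for a vector f."""
    op = float(np.linalg.norm(np.diag(f), 2))   # largest singular value of diag(f)
    sup = float(np.abs(f).max())
    l2 = float(np.linalg.norm(f))
    return op, sup, l2

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    f = rng.normal(size=200)
    op, sup, l2 = diag_norms(f)
    # op equals sup, and sup <= l2 <= sqrt(m) * sup, so passing from the sup norm
    # to the Euclidean norm can lose up to a factor sqrt(m).
    print(op, sup, l2, np.sqrt(f.size) * sup)
```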

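For $f_t \in D_s$ with some $s \le 2r + 2$, we can write it as $f_t = e_k e_{k-1} \cdots e_1 e_0$ where $0 \le k \le 4H - 3$,

$$e_0 \in \Big\{ a_t,\ \{\mathbf{1}, \sqrt{m}\, x_\beta^{(1)}, \sqrt{m}\, x_\beta^{(2)}, \cdots, \sqrt{m}\, x_\beta^{(H)}\}_{1 \le \beta \le n} \Big\},$$

and for $1 \le j \le k$, $e_j$ belongs to one of the sets

$$\Big\{ \Big\{ \frac{W_t^{(2)}}{\sqrt{m}}, \frac{(W_t^{(2)})^\top}{\sqrt{m}}, \cdots, \frac{W_t^{(H)}}{\sqrt{m}}, \frac{(W_t^{(H)})^\top}{\sqrt{m}} \Big\},\ \{\sigma_1'(x_\beta), \sigma_2'(x_\beta), \cdots, \sigma_H'(x_\beta)\}_{1 \le \beta \le n} \Big\}, \tag{B.11}$$

$$\{\mathrm{diag}(d) : d \in D_0 \cup D_1 \cup \cdots \cup D_{s-1}\}, \tag{B.12}$$

$$\big\{ \sigma_\ell^{(u+1)}(x_\beta)\, \mathrm{diag}(d_1)\, \mathrm{diag}(d_2) \cdots \mathrm{diag}(d_u) : 1 \le \ell \le H,\ 1 \le \beta \le n,\ 1 \le u \le s - 1,\ d_1, d_2, \cdots, d_u \in D_0 \cup D_1 \cup \cdots \cup D_{s-1} \big\}. \tag{B.13}$$

Moreover, the total number of diag operations in the expression $f_t = e_k e_{k-1} \cdots e_1 e_0$ is exactly $s$. We remark that $x_\beta^{(\ell)}$ depends on time $t$. In the following we prove by induction on $s$ that

$$\|f_t\|_2 = \|e_k e_{k-1} \cdots e_1 e_0\|_2 \lesssim M^s \sqrt{m}, \tag{B.14}$$

which gives the claim (B.10).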

For $s = 0$, $f_t$ does not contain diag operations, and all the $e_j$ belong to (B.11). In this case, thanks to Assumption 2.1 and Proposition B.1, with high probability with respect to the random initialization, $\|e_0\|_2 \lesssim \sqrt{m}$ and $\|e_j\|_{2 \to 2} \lesssim 1$ for all $1 \le j \le k$. Therefore, we have that

$$\|f_t\|_2 = \|e_k e_{k-1} \cdots e_1 e_0\|_2 \le \|e_k\|_{2 \to 2} \|e_{k-1}\|_{2 \to 2} \cdots \|e_1\|_{2 \to 2} \|e_0\|_2 \lesssim \sqrt{m}.$$

For $1 \le s \le r$, by our assumption (B.9),

$$\|f_t\|_2 = \|e_k e_{k-1} \cdots e_1 e_0\|_2 \le \sqrt{m}\, \|e_k e_{k-1} \cdots e_1 e_0\|_\infty = \sqrt{m}\, \|\mathrm{diag}(f_t)\|_{2 \to 2} \le M \sqrt{m} \le M^s \sqrt{m}.$$

Thus the claim (B.14) holds for any $s \le r$.

In the following we assume that (B.14) holds for $1, 2, \cdots, s - 1$ and prove it for $s$. For each $1 \le j \le k$, we denote the number of diag operations in $e_j$ by $s_j$. Then the total number of diag operations in $f_t$ is $s_1 + s_2 + \cdots + s_k = s \le 2r + 2$. As an easy consequence, $s_j \le s \le 2r + 2$ for any $1 \le j \le k$. For each term $e_j$, there are several cases:

(i) $e_j$ belongs to (B.11). Then it does not contain any diag operation, and we have $s_j = 0$. In this case, we have proven that $\|e_j\|_{2\to 2} \lesssim 1$.

(ii) $e_j = \operatorname{diag}(d)$ belongs to (B.12) and $s_j \le r+1$. Since the expression $d$ contains $s_j - 1 \le r$ diag operations, $d$ also belongs to $D_{s_j-1} \subset D_0\cup D_1\cup\cdots\cup D_r$. By the assumption (B.9),

$$\|e_j\|_{2\to 2} = \|\operatorname{diag}(d)\|_{2\to 2} \le M.$$

(iii) $e_j = \operatorname{diag}(d)$ belongs to (B.12) and $s_j > r+1$. In this case we still have that $s_j - 1 \le s-1$. By our induction assumption (B.14) and $d \in D_{s_j-1}$, it holds that

$$\|d\|_2 \lesssim M^{s_j-1}\sqrt{m}. \tag{B.15}$$

(iv) $e_j = \sigma_\ell^{(u+1)}(x_\beta)\operatorname{diag}(d_1)\operatorname{diag}(d_2)\cdots\operatorname{diag}(d_u)$ belongs to (B.13), and each of the subexpressions $\operatorname{diag}(d_1), \operatorname{diag}(d_2), \dots, \operatorname{diag}(d_u)$ contains at most $r+1$ diag operations. In this case we have $u \le s_j$ and $d_1, d_2, \dots, d_u \in D_0\cup D_1\cup\cdots\cup D_r$. By the assumption (B.9) and Assumption 2.1,

$$\|e_j\|_{2\to 2} \le \|\sigma_\ell^{(u+1)}(x_\beta)\|_{2\to 2}\|\operatorname{diag}(d_1)\|_{2\to 2}\|\operatorname{diag}(d_2)\|_{2\to 2}\cdots\|\operatorname{diag}(d_u)\|_{2\to 2} \le c_{u+1}M^u \lesssim M^{s_j}.$$

(v) $e_j = \sigma_\ell^{(u+1)}(x_\beta)\operatorname{diag}(d_1)\operatorname{diag}(d_2)\cdots\operatorname{diag}(d_u)$ belongs to (B.13), and some of the subexpressions $\operatorname{diag}(d_1), \operatorname{diag}(d_2), \dots, \operatorname{diag}(d_u)$ contain more than $r+1$ diag operations. In this case $s_j > r+1$. Since the total number of diag operations in $e_j$ is $s_j \le 2r+2$, exactly one of $\operatorname{diag}(d_1), \operatorname{diag}(d_2), \dots, \operatorname{diag}(d_u)$ contains more than $r+1$ diag operations. Say it is $\operatorname{diag}(d_v)$. For any $1 \le i \ne v \le u$, $\operatorname{diag}(d_i)$ contains at most $r+1$ diag operations and $d_i \in D_0\cup D_1\cup\cdots\cup D_r$. By the assumption (B.9),

$$\|d_i\|_\infty = \|\operatorname{diag}(d_i)\|_{2\to 2} \le M. \tag{B.16}$$

The vector $d_v$ contains at most $s_j - u \le s-1$ diag operations. Thus by our induction assumption (B.14), it holds that

$$\|d_v\|_2 \lesssim M^{s_j-u}\sqrt{m}. \tag{B.17}$$

By our assumption, the total number of diag operations in $f_t = e_k e_{k-1}\cdots e_0$ is $s_1 + s_2 + \cdots + s_k = s \le 2r+2$. At most one of those $s_j$ is bigger than $r+1$. In particular, at most one of those $e_j$ belongs to cases (iii) or (v). If none of those $e_j$ belongs to cases (iii) or (v), then with the bound $\|e_0\|_2 \lesssim \sqrt{m}$, we have

$$\|e_k e_{k-1}\cdots e_1 e_0\|_2 \le \|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_1\|_{2\to 2}\|e_0\|_2 \lesssim \prod_{i=1}^k M^{s_i}\sqrt{m} = M^s\sqrt{m}. \tag{B.18}$$

If for some $1 \le j \le k$, $e_j$ belongs to case (iii), we write

$$\begin{aligned}
\|e_k e_{k-1}\cdots e_1 e_0\|_2 &\le \|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}\,\|e_j e_{j-1}\cdots e_1 e_0\|_2 \\
&\le \|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}\,\|d\|_2\,\|e_{j-1}e_{j-2}\cdots e_1 e_0\|_\infty.
\end{aligned} \tag{B.19}$$

The expression $e_{j-1}e_{j-2}\cdots e_1 e_0$ contains at most $s_1 + s_2 + \cdots + s_{j-1} \le r$ diag operators, and $j-1 \le k-1 < 4H-3$. Therefore, the expression $e_{j-1}e_{j-2}\cdots e_1 e_0$ is in the set $D_{s_1+s_2+\cdots+s_{j-1}}$, which is contained in $D_0\cup D_1\cup\cdots\cup D_r$. Thus by our assumption (B.9),

$$\|e_{j-1}e_{j-2}\cdots e_1 e_0\|_\infty = \|\operatorname{diag}(e_{j-1}e_{j-2}\cdots e_1 e_0)\|_{2\to 2} \le M. \tag{B.20}$$

We estimate $\|d\|_2$ using (B.15), and estimate $\|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}$ by (i), (ii) and (iv). By plugging (B.15) and (B.20) into (B.19), we get

$$\|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}\,\|d\|_2\,\|e_{j-1}e_{j-2}\cdots e_1 e_0\|_\infty \lesssim \prod_{i=j+1}^k M^{s_i}\,\big(M^{s_j-1}\sqrt{m}\big)\,M \lesssim M^s\sqrt{m}.$$

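If for some $1 \le j \le k$, $e_j$ belongs to case (v), we write

$$\begin{aligned}
\|e_k e_{k-1}\cdots e_1 e_0\|_2 &\le \|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}\,\|e_j e_{j-1}\cdots e_1 e_0\|_2 \\
&= \|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}\,\|\sigma_\ell^{(u+1)}(x_\beta)\operatorname{diag}(d_1)\operatorname{diag}(d_2)\cdots\operatorname{diag}(d_u)\,e_{j-1}\cdots e_1 e_0\|_2.
\end{aligned} \tag{B.21}$$

For the last term in (B.21), we have

$$\begin{aligned}
&\|\sigma_\ell^{(u+1)}(x_\beta)\operatorname{diag}(d_1)\operatorname{diag}(d_2)\cdots\operatorname{diag}(d_u)\,e_{j-1}\cdots e_1 e_0\|_2 \\
&\quad\le \|\sigma_\ell^{(u+1)}(x_\beta)\operatorname{diag}(d_1)\cdots\operatorname{diag}(d_{v-1})\|_{2\to 2}\,\|d_v\|_2\,\|\operatorname{diag}(d_{v+1})\cdots\operatorname{diag}(d_u)\,e_{j-1}\cdots e_1 e_0\|_\infty \\
&\quad\le \|\sigma_\ell^{(u+1)}(x_\beta)\|_{2\to 2}\,\|d_1\|_\infty\cdots\|d_{v-1}\|_\infty\,\|d_v\|_2\,\|d_{v+1}\|_\infty\cdots\|d_u\|_\infty\,\|e_{j-1}\cdots e_1 e_0\|_\infty.
\end{aligned} \tag{B.22}$$

Plugging (B.16), (B.17) and (B.20) into (B.22), we get

$$\|\sigma_\ell^{(u+1)}(x_\beta)\operatorname{diag}(d_1)\operatorname{diag}(d_2)\cdots\operatorname{diag}(d_u)\,e_{j-1}\cdots e_1 e_0\|_2 \lesssim M^{u-1}\,\|d_v\|_2\,\|e_{j-1}\cdots e_1 e_0\|_\infty \lesssim M^{u-1}M^{s_j-u}\sqrt{m}\,M = M^{s_j}\sqrt{m}. \tag{B.23}$$

We estimate $\|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}$ by (i), (ii) and (iv). By plugging (B.23) into (B.21), we get

$$\|e_k\|_{2\to 2}\|e_{k-1}\|_{2\to 2}\cdots\|e_{j+1}\|_{2\to 2}\,\|e_j e_{j-1}e_{j-2}\cdots e_1 e_0\|_2 \lesssim \prod_{i=j+1}^k M^{s_i}\,\big(M^{s_j}\sqrt{m}\big) \lesssim M^s\sqrt{m}.$$

This finishes the proof of (B.14), and hence Proposition B.2. Notice that the inequalities (B.15) and (B.17), which contain the factor $\sqrt{m}$, were used only once in this proof.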

Proposition B.3. We assume Assumptions 2.1 and 2.2. With high probability, uniformly for any vector $f_t \in D_0\cup D_1\cup\cdots\cup D_{p-1}$ and time $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, the following holds:

$$\|f_t\|_\infty \lesssim (\ln m)^C.$$

Proof of Proposition B.3. Thanks to Corollary A.5, with high probability, uniformly for all $f_t \in D_0\cup D_1\cup\cdots\cup D_{p-1}$, we have that

$$\|f_0\|_\infty \lesssim (\ln m)^C.$$

We denote

$$\xi(t) = \max\{\|f_t\|_\infty :\ f_t \in D_0\cup D_1\cup\cdots\cup D_{p-1}\}. \tag{B.24}$$

In the following we derive a self-consistent differential equation of $\xi(t)$. Proposition B.3 follows from analyzing it. For any $f_t(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}) \in D_0\cup D_1\cup\cdots\cup D_{p-1}$, by taking the derivative we have

$$\partial_t f_t(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}) = -\frac{1}{n}\sum_{\beta=1}^n \frac{1}{\sqrt{m}}\, f_t^{(1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_\beta)\,(f_\beta(t) - y_\beta), \tag{B.25}$$

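where $f_t^{(1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_\beta)$ is obtained from $f_t(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s})$ by the replacements (3.1). We define

$$\operatorname{LinProd}(D_0\cup D_1\cup\cdots\cup D_{p+r}) = \left\{ \text{linear combinations of } v_0(t)\prod_{j=1}^u \frac{\langle v_{2j-1}(t), v_{2j}(t)\rangle}{m} \right\},$$

where $0 \le u \le r+1$, $v_j(t) \in D_{s_j}$, $s_0, s_1, \dots, s_{2u} \ge 0$ and $s_0 + s_1 + \cdots + s_{2u} \le p+r$. Thanks to Proposition A.1, $f_t^{(1)} \in \operatorname{LinProd}(D_0\cup D_1\cup\cdots\cup D_p)$. More generally, for any integer $1 \le r \le p+1$,

$$\partial_t f_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}) = -\frac{1}{n}\sum_{\beta=1}^n \frac{1}{\sqrt{m}}\, f_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}, x_\beta)\,(f_\beta(t) - y_\beta), \tag{B.26}$$

where $f_t^{(r+1)} \in \operatorname{LinProd}(D_0\cup D_1\cup\cdots\cup D_{p+r})$. Using the bound (B.1), we have

$$\begin{aligned}
&\left\| \frac{1}{n}\sum_{\beta=1}^n \frac{1}{\sqrt{m}}\, f_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}, x_\beta)\,(f_\beta(t) - y_\beta) \right\|_\infty \\
&\quad\le \frac{1}{\sqrt{m}}\max_{1\le\beta\le n}\|f_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}, x_\beta)\|_\infty\, \frac{1}{n}\sum_{\beta=1}^n |f_\beta(t) - y_\beta| \\
&\quad\le \frac{1}{\sqrt{m}}\max_{1\le\beta\le n}\|f_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}, x_\beta)\|_\infty\, \sqrt{\frac{1}{n}\sum_{\beta=1}^n (f_\beta(t) - y_\beta)^2} \\
&\quad\lesssim \frac{1}{\sqrt{m}}\max_{1\le\beta\le n}\|f_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}, x_\beta)\|_\infty.
\end{aligned}$$

Therefore, (B.25) and (B.26) together give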

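$$\|\partial_t f_t(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s})\|_\infty \lesssim \frac{1}{\sqrt{m}}\max_{1\le\beta\le n}\|f_t^{(1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_\beta)\|_\infty, \tag{B.27}$$

$$\|\partial_t f_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}})\|_\infty \lesssim \frac{1}{\sqrt{m}}\max_{1\le\beta\le n}\|f_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_{s+r}}, x_\beta)\|_\infty, \tag{B.28}$$

for any $1 \le r \le p+1$. By taking higher derivatives on both sides of (B.27), and using (B.28) to bound the righthand side, we have that

$$\begin{aligned}
\|\partial_t^{(p+1)} f_t(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s})\|_\infty
&\lesssim \frac{1}{\sqrt{m}}\max_{1\le\beta_1\le n}\|\partial_t^{(p)} f_t^{(1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_{\beta_1})\|_\infty \\
&\lesssim \cdots \\
&\lesssim \frac{1}{m^{(p+1)/2}}\max_{1\le\beta_1,\beta_2,\dots,\beta_{p+1}\le n}\|f_t^{(p+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_{\beta_1}, x_{\beta_2}, \dots, x_{\beta_{p+1}})\|_\infty.
\end{aligned} \tag{B.29}$$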

From the discussion above, $f_t^{(p+1)}$ is a linear combination of terms of the form

$$v_0(t)\prod_{j=1}^u \frac{\langle v_{2j-1}(t), v_{2j}(t)\rangle}{m}, \tag{B.30}$$

where $0 \le u \le p+1$, $v_j(t) \in D_{s_j}$, $s_0, s_1, \dots, s_{2u} \ge 0$ and $s_0 + s_1 + \cdots + s_{2u} \le 2p$. We can use Proposition B.2 for $r = p-1$,

$$\|v_0(t)\|_\infty \lesssim \xi(t)^{s_0}\sqrt{m}, \tag{B.31}$$

and for $1 \le j \le 2u$,

$$\|v_j(t)\|_2 \lesssim \xi(t)^{s_j}\sqrt{m}. \tag{B.32}$$

The estimates (B.31) and (B.32) together give an upper bound for the $L^\infty$ norm of $f_t^{(p+1)}$,

$$\|f_t^{(p+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_{\beta_1}, x_{\beta_2}, \dots, x_{\beta_{p+1}})\|_\infty \lesssim \sqrt{m}\prod_{j=0}^{2u}\xi(t)^{s_j} \lesssim \xi(t)^{2p}\sqrt{m}. \tag{B.33}$$

We obtain a self-consistent differential equation of $\xi(t)$ by taking the maximum on both sides of (B.29) over $1 \le \alpha_1, \alpha_2, \dots, \alpha_s \le n$, and using (B.33):

$$\partial_t^{(p+1)}\xi(t) \lesssim \frac{\xi(t)^{2p}}{m^{p/2}}. \tag{B.34}$$

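To obtain an upper bound of $\xi(t)$ using (B.34), we still need an upper bound for the initial data, i.e. $\xi(0)$ and $\{\partial_t^{(r)}\xi(0)\}_{1\le r\le p}$. Fortunately, Corollary A.5 provides such estimates. In fact, Corollary A.5 implies that with high probability $\xi(0) \lesssim (\ln m)^C$. For the derivatives of $\xi(t)$ at $t = 0$, we use (B.29):

$$\begin{aligned}
|\partial_t^{(r)}\xi(0)| &\lesssim \max_{1\le\alpha_1,\alpha_2,\dots,\alpha_s\le n}\left.\|\partial_t^{(r)} f_t(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s})\|_\infty\right|_{t=0} \\
&\lesssim \frac{1}{m^{r/2}}\max_{1\le\alpha_1,\alpha_2,\dots,\alpha_s\le n}\ \max_{1\le\beta_1,\beta_2,\dots,\beta_r\le n}\|f_0^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_s}, x_{\beta_1}, x_{\beta_2}, \dots, x_{\beta_r})\|_\infty.
\end{aligned}$$

Again $f_0^{(r)}$ is a linear combination of terms of the form (B.30) with $v_j(t) \in D_{s_j}$ for some $0 \le u \le r$, $s_0, s_1, \dots, s_{2u} \ge 0$ and $s_0 + s_1 + \cdots + s_{2u} \le p+r-1$. Using Corollary A.5, for $0 \le j \le 2u$, $\|v_j(0)\|_\infty \lesssim (\ln m)^C$. We conclude that

$$|\partial_t^{(r)}\xi(0)| \lesssim \frac{(\ln m)^{(2r+1)C}}{m^{r/2}},$$

for any $1 \le r \le p$. The ordinary differential equation (B.34), with equality in place of $\lesssim$, has an exact solution of the following form: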

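$$\tilde\xi(t) = \frac{A_0\, m^{\frac{p}{2(2p-1)}}}{\left( A_1\, m^{\frac{p}{2(p+1)}}/(\ln m)^{\frac{(2p-1)C}{p+1}} - t \right)^{\frac{p+1}{2p-1}}},$$

where $A_0, A_1$ are constants depending on $p$, chosen such that $\tilde\xi(t)$ is an exact solution of (B.34) and $\tilde\xi(0) = \xi(0)$. It is easy to check that $\tilde\xi(0) \asymp (\ln m)^C$, and for $1 \le r \le p+1$,

$$\partial_t^{(r)}\tilde\xi(0) \asymp (\ln m)^{\left(1 + \frac{(2p-1)r}{p+1}\right)C}\, m^{-\frac{pr}{2(p+1)}} \gg \frac{(\ln m)^{(2r+1)C}}{m^{r/2}} \gtrsim \partial_t^{(r)}\xi(0),$$

provided that $m$ is large enough. Therefore, $\tilde\xi(t)$ provides an upper bound for $\xi(t)$. We conclude that for

$$t \lesssim m^{\frac{p}{2(p+1)}}/(\ln m)^{\frac{(2p-1)C}{p+1}},$$

it holds that

$$\xi(t) \lesssim \tilde\xi(t) \lesssim \frac{A_0\, m^{\frac{p}{2(2p-1)}}}{\left( A_1\, m^{\frac{p}{2(p+1)}}/(\ln m)^{\frac{(2p-1)C}{p+1}} \right)^{\frac{p+1}{2p-1}}} \lesssim (\ln m)^C.$$

This finishes the proof of Proposition B.3.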
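The power-law ansatz used in the proof can be checked numerically. The sketch below assumes a function of the form $A\,(T-t)^{-a}$ with $a = (p+1)/(2p-1)$ and verifies that, for a suitable $A$ proportional to $m^{p/(2(2p-1))}$, its $(p+1)$-th derivative equals its $2p$-th power divided by $m^{p/2}$. The blow-up time `T`, the values of `p`, `m`, `t`, and the function name are hypothetical choices for illustration; in the proof $T$ corresponds to $A_1 m^{p/(2(p+1))}/(\ln m)^{(2p-1)C/(p+1)}$.

```python
import math

def xi_tilde_derivative(k, t, p, m, T):
    """k-th time derivative of A * (T - t)**(-a), with a = (p + 1) / (2p - 1)."""
    a = (p + 1) / (2 * p - 1)
    # Choose A so that the (p+1)-th derivative equals xi**(2p) / m**(p/2):
    # A**(2p - 1) = a (a+1) ... (a+p) * m**(p/2).
    rising = math.prod(a + i for i in range(p + 1))
    A = (rising * m ** (p / 2)) ** (1 / (2 * p - 1))
    # d^k/dt^k (T - t)**(-a) = a (a+1) ... (a+k-1) * (T - t)**(-a - k).
    coeff = math.prod(a + i for i in range(k))   # empty product is 1 for k == 0
    return A * coeff * (T - t) ** (-a - k)

# Hypothetical parameter values, for illustration only.
p, m, T, t = 3, 1.0e6, 50.0, 10.0
lhs = xi_tilde_derivative(p + 1, t, p, m, T)                        # (p+1)-th derivative
rhs = xi_tilde_derivative(0, t, p, m, T) ** (2 * p) / m ** (p / 2)  # xi^(2p) / m^(p/2)
```

The two sides agree up to floating-point error because the exponent satisfies $a + p + 1 = 2pa$, which is exactly what fixes $a = (p+1)/(2p-1)$.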

Remark B.4. By the same argument as for (B.34), for any $0 \le r \le p$, we have

$$\partial_t^{(r+1)}\xi(t) \lesssim \frac{\xi(t)^{p+r}}{m^{r/2}},$$

which gives us that $\xi(t) \lesssim (\ln m)^C$ for $t \le m^{\frac{r}{2(r+1)}}/(\ln m)^{C'}$. Therefore, for bigger $r$, we have the a priori estimate $\xi(t) \lesssim (\ln m)^C$ for a longer time.

Proof of (2.4) in Theorem 2.3. From the discussion in Section 3 (3.8), we have that each summand in $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_r})$ is of the form

$$\frac{1}{m^{r/2-1}}\prod_{j=1}^s \frac{\langle v_{2j-1}(t), v_{2j}(t)\rangle}{m}, \qquad 1 \le s \le r, \quad v_j \in D_0\cup D_1\cup\cdots\cup D_{r-2}. \tag{B.35}$$

If $r \le p+1$, Proposition B.3 provides an upper bound on the $L^\infty$ norm of those vectors $v_j(t)$, so we can bound the inner products $\langle v_{2j-1}(t), v_{2j}(t)\rangle$ using Proposition B.3. If $r \le p+1$, then for $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds that

$$\|v_j(t)\|_\infty \lesssim (\ln m)^C.$$

As a consequence, with high probability with respect to the random initialization,

$$\left| \frac{1}{m^{r/2-1}}\prod_{j=1}^s \frac{\langle v_{2j-1}(t), v_{2j}(t)\rangle}{m} \right| \lesssim \frac{1}{m^{r/2-1}}\prod_{j=1}^s \frac{(\ln m)^{2C}\, m}{m} = \frac{(\ln m)^{2sC}}{m^{r/2-1}} \le \frac{(\ln m)^{2rC}}{m^{r/2-1}}.$$

Since $K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_r})$ is a linear combination of terms of the form (B.35), the claim (2.4) follows.

C Proof of Corollary 2.4 and 2.5, and Theorem 2.6

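Proof of Corollary 2.4. We first derive an upper bound of the kernel $K_t^{(3)}(\cdot,\cdot,\cdot)$, using its derivative

$$\partial_t K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3}) = -\frac{1}{n}\sum_{\beta=1}^n K_t^{(4)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3}, x_\beta)\,(f_\beta(t) - y_\beta). \tag{C.1}$$

Thanks to (2.4), for $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds that

$$\|K_t^{(4)}\|_\infty \lesssim \frac{(\ln m)^C}{m}. \tag{C.2}$$

Combining (C.2) with (B.1) implies an upper bound of the righthand side of (C.1),

$$|\partial_t K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})| \le \max_{1\le\beta\le n}|K_t^{(4)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3}, x_\beta)|\, \frac{1}{n}\sum_{\beta=1}^n |f_\beta(t) - y_\beta| \lesssim \frac{(\ln m)^C}{m}. \tag{C.3}$$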

(2.3) gives the upper bound $|K_0^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})| \lesssim (\ln m)^C/m$, and (C.3) gives an upper bound of the derivative of $K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})$. Together they imply that with high probability

$$|K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_{\alpha_3})| \lesssim \frac{(1+t)(\ln m)^C}{m}, \tag{C.4}$$

for any $1 \le \alpha_1, \alpha_2, \alpha_3 \le n$. We recall that

$$\partial_t K_t^{(2)}(x_{\alpha_1}, x_{\alpha_2}) = -\frac{1}{n}\sum_{\beta=1}^n K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_\beta)\,(f_\beta(t) - y_\beta). \tag{C.5}$$

Similarly to (C.3), we can use (C.4) to upper bound the righthand side of (C.5),

$$|\partial_t K_t^{(2)}(x_{\alpha_1}, x_{\alpha_2})| \le \max_{1\le\beta\le n}|K_t^{(3)}(x_{\alpha_1}, x_{\alpha_2}, x_\beta)|\, \frac{1}{n}\sum_{\beta=1}^n |f_\beta(t) - y_\beta| \lesssim \frac{(1+t)(\ln m)^C}{m}.$$

This finishes the proof of Corollary 2.4.

Proof of Corollary 2.5. Corollary 2.4 gives the rate of change of each entry of the NTK up to time $t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$,

$$\|\partial_t K_t^{(2)}\|_\infty \lesssim \frac{(1+t)(\ln m)^C}{m}. \tag{C.6}$$

By integrating both sides of (C.6) from $0$ to $t$, we get an $L^\infty$ bound of the change of the NTK,

$$\|K_t^{(2)} - K_0^{(2)}\|_\infty \lesssim \frac{t(1+t)(\ln m)^C}{m}. \tag{C.7}$$

The $L^\infty$ bound in (C.7) can be used to derive a norm bound of the change of the NTK,

$$\|K_t^{(2)} - K_0^{(2)}\|_{2\to 2} \le \|K_t^{(2)} - K_0^{(2)}\|_F \le n\,\|K_t^{(2)} - K_0^{(2)}\|_\infty \lesssim \frac{t(1+t)(\ln m)^C\, n}{m}.$$
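The chain of norm inequalities used above (operator norm, then Frobenius norm, then $n$ times the entrywise maximum) can be illustrated on a small matrix. The sketch below is an illustration under assumptions: the random symmetric matrix stands in for $K_t^{(2)} - K_0^{(2)}$, and the operator norm is estimated by a plain power iteration, which only ever produces a lower estimate of the true norm.

```python
import math
import random

def spectral_norm_sym(A, iters=500):
    """Power-iteration estimate (from below) of the 2->2 norm of a symmetric matrix."""
    n = len(A)
    v = [1.0 / math.sqrt(n)] * n
    norm = 0.0
    for _ in range(iters):
        w = [sum(A[i][j] * v[j] for j in range(n)) for i in range(n)]
        norm = math.sqrt(sum(x * x for x in w))   # ||A v||_2 with ||v||_2 = 1
        if norm == 0.0:
            return 0.0
        v = [x / norm for x in w]
    return norm

random.seed(1)
n = 6
# Hypothetical symmetric matrix playing the role of the kernel difference.
A = [[0.0] * n for _ in range(n)]
for i in range(n):
    for j in range(i, n):
        A[i][j] = A[j][i] = random.uniform(-1.0, 1.0)

op = spectral_norm_sym(A)                                  # ~ ||A||_{2->2}
fro = math.sqrt(sum(x * x for row in A for x in row))      # ||A||_F
sup = max(abs(x) for row in A for x in row)                # ||A||_inf (entrywise)
```

Here `op <= fro <= n * sup`, which is the passage from the entrywise bound (C.7) to the operator-norm bound.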

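The change of the smallest eigenvalue of the NTK is upper bounded by the change of its norm. If $0 \le t \le c\sqrt{\lambda m/n}/(\ln m)^{C/2}$, with some $c > 0$ small enough, the change of the norm satisfies $\|K_t^{(2)} - K_0^{(2)}\|_{2\to 2} \le \lambda/2$. Combining with (2.5), we conclude that for $0 \le t \le c\sqrt{\lambda m/n}/(\ln m)^{C/2}$,

$$\lambda_{\min}\left[K_t^{(2)}(x_\alpha, x_\beta)\right]_{1\le\alpha,\beta\le n} \ge \lambda_{\min}\left[K_0^{(2)}(x_\alpha, x_\beta)\right]_{1\le\alpha,\beta\le n} - \|K_t^{(2)} - K_0^{(2)}\|_{2\to 2} \ge \lambda/2. \tag{C.8}$$

From the defining relation (1.6) of the NTK and using (C.8), we have

$$\begin{aligned}
\partial_t\sum_{\beta=1}^n \|f_\beta(t) - y_\beta\|^2 &= -\frac{1}{n}\sum_{\alpha,\beta=1}^n K_t^{(2)}(x_\alpha, x_\beta)\,(f_\alpha(t) - y_\alpha)(f_\beta(t) - y_\beta) \\
&\le -\frac{\lambda}{2n}\sum_{\beta=1}^n \|f_\beta(t) - y_\beta\|^2,
\end{aligned} \tag{C.9}$$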

for $0 \le t \le c\sqrt{\lambda m/n}/(\ln m)^{C/2}$. In particular, (C.9) implies an exponential decay of the training error,

$$\sum_{\beta=1}^n \|f_\beta(t) - y_\beta\|^2 \le e^{-\frac{\lambda t}{2n}}\sum_{\beta=1}^n \|f_\beta(0) - y_\beta\|^2 \lesssim n\, e^{-\frac{\lambda t}{2n}}, \tag{C.10}$$

for $0 \le t \le c\sqrt{\lambda m/n}/(\ln m)^{C/2}$. It takes time $t \asymp (2n/\lambda)\ln(n/\varepsilon)$ for the training error in (C.10) to reach $\varepsilon$. Therefore if

$$c\sqrt{\lambda m/n}/(\ln m)^{C/2} \gtrsim (2n/\lambda)\ln(n/\varepsilon), \tag{C.11}$$

the dynamic (2.1) finds a global minimum: the training error reaches $\varepsilon$ at time $t \asymp (n/\lambda)\ln(n/\varepsilon)$. For (C.11) to hold, the neural network needs to be wide,

$$m \ge C'\left(\frac{n}{\lambda}\right)^3 (\ln m)^C \ln(n/\varepsilon)^2,$$

with some large constant $C'$. This finishes the proof of Corollary 2.5.
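The time scale $(2n/\lambda)\ln(n/\varepsilon)$ can be read off directly from the bound $n\,e^{-\lambda t/(2n)}$ in (C.10). The following sketch assumes the implicit constant in that bound is $1$ and uses hypothetical values of $n$, $\lambda$ and $\varepsilon$; the function name is likewise an illustration, not notation from the paper.

```python
import math

def time_to_reach(eps, n, lam):
    """Time at which n * exp(-lam * t / (2n)) equals eps (implicit constant taken as 1)."""
    return (2 * n / lam) * math.log(n / eps)

# Hypothetical sample size, smallest-eigenvalue bound, and target error.
n, lam, eps = 100, 0.5, 1e-3
t_eps = time_to_reach(eps, n, lam)
bound_at_t = n * math.exp(-lam * t_eps / (2 * n))   # value of the bound at time t_eps
```

Requiring `t_eps` to be at most $c\sqrt{\lambda m/n}/(\ln m)^{C/2}$ and solving for $m$ gives the cubic dependence on $n/\lambda$ in the width condition.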

Proof of Theorem 2.6. We have proven in (B.1) that

$$\sum_{\beta=1}^n (f_\beta(t) - y_\beta)^2 = O(n). \tag{C.12}$$

We recall the a priori estimate (2.4): with high probability with respect to the random initialization, for $0 \le t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds that

$$\|K_t^{(r)}\|_\infty \lesssim \frac{(\ln m)^C}{m^{r/2-1}}. \tag{C.13}$$

We have better estimates if $r$ is odd. In fact, thanks to the equations for the dynamic of the NTK (2.2),

$$\begin{aligned}
|\partial_t K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_r})| &\le \max_{1\le\beta\le n}|K_t^{(r+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_r}, x_\beta)|\, \frac{1}{n}\sum_{\beta}|f_\beta(t) - y_\beta| \\
&\lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}\sqrt{\frac{1}{n}\sum_{\beta}|f_\beta(t) - y_\beta|^2} \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}.
\end{aligned} \tag{C.14}$$

Moreover, thanks to (2.3), if $r$ is odd, $K^{(r)} = 0$ and

$$\|K_0^{(r)}\|_\infty \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}. \tag{C.15}$$

The estimates (C.14) and (C.15) together imply that if $r$ is odd, for $t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$,

$$\|K_t^{(r)}\|_\infty \lesssim \frac{(1+t)(\ln m)^C}{m^{(r-1)/2}}, \tag{C.16}$$

which is slightly better than the estimate (C.13).

We denote the vector

$$\Delta f(t) = \big(f_1(t) - \tilde f_1(t),\ f_2(t) - \tilde f_2(t),\ \dots,\ f_n(t) - \tilde f_n(t)\big)^\top.$$

At $t = 0$, we have $\Delta f(t) = 0$. We denote by $T$ the first time that $\|\Delta f(t)\|_2 \ge \sqrt{n}$, i.e. $T = \inf_{t\ge 0}\{t : \|\Delta f(t)\|_2 \ge \sqrt{n}\}$. Then for $t \le T$, we have that

$$\sum_{\beta=1}^n (f_\beta(t) - y_\beta)^2 = O(n), \qquad \sum_{\beta=1}^n (\tilde f_\beta(t) - y_\beta)^2 = O(n).$$

Next we study the difference of the original dynamic and the truncated dynamic for $t \le T$. We show that $\|\Delta f(t)\|_2$ is much smaller than $\sqrt{n}$ when $t \le \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$. As a consequence, $T \ge \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}\}$.

Thanks to (2.4) and (B.1), we have that

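$$\left|\partial_t\big(K_t^{(p)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_p}) - \tilde K_t^{(p)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_p})\big)\right| \le \max_{1\le\beta\le n}|K_t^{(p+1)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_p}, x_\beta)|\, \frac{1}{n}\sum_{\beta=1}^n |f_\beta(t) - y_\beta| \lesssim \frac{1+t}{m^{p/2}}.$$

Thus for $t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$ we have

$$\|K_t^{(p)} - \tilde K_t^{(p)}\|_\infty \lesssim \frac{(1+t)\,t}{m^{p/2}}.$$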

By taking difference of (2.2) and (2.7), we have

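$$\begin{aligned}
&\partial_t\big(K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_r}) - \tilde K_t^{(r)}(x_{\alpha_1}, x_{\alpha_2}, \dots, x_{\alpha_r})\big) \\
&\quad= -\frac{1}{n}\sum_{\beta=1}^n \Big( K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta)\,(f_\beta(t) - y_\beta) - \tilde K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta)\,(\tilde f_\beta(t) - y_\beta) \Big) \\
&\quad= -\frac{1}{n}\sum_{\beta=1}^n \Big( \big(K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta) - \tilde K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta)\big)(\tilde f_\beta(t) - y_\beta) \\
&\qquad\qquad\qquad + K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta)\,(f_\beta(t) - \tilde f_\beta(t)) \Big).
\end{aligned} \tag{C.17}$$

We estimate the first term on the righthand side of (C.17) as

$$\begin{aligned}
&\left| \frac{1}{n}\sum_{\beta=1}^n \big(K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta) - \tilde K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta)\big)(\tilde f_\beta(t) - y_\beta) \right| \\
&\quad\le \|K_t^{(r+1)} - \tilde K_t^{(r+1)}\|_\infty\, \frac{1}{n}\sum_{\beta=1}^n |\tilde f_\beta(t) - y_\beta| \lesssim \|K_t^{(r+1)} - \tilde K_t^{(r+1)}\|_\infty,
\end{aligned} \tag{C.18}$$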

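provided that $t \le \min\{m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$. For the second term on the righthand side of (C.17), for $t \le m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}$, it holds that

$$\begin{aligned}
&\left| \frac{1}{n}\sum_{\beta=1}^n K_t^{(r+1)}(x_{\alpha_1}, \dots, x_{\alpha_r}, x_\beta)\,(f_\beta(t) - \tilde f_\beta(t)) \right| \\
&\quad\le \|K_t^{(r+1)}\|_\infty\, \frac{1}{n}\sum_{\beta=1}^n |f_\beta(t) - \tilde f_\beta(t)| \le \|K_t^{(r+1)}\|_\infty \sqrt{\frac{1}{n}\sum_{\beta=1}^n |f_\beta(t) - \tilde f_\beta(t)|^2} \\
&\quad= \|K_t^{(r+1)}\|_\infty\, \frac{\|\Delta f(t)\|_2}{\sqrt{n}} \lesssim \frac{(\ln m)^C}{m^{(r-1)/2}}\left(\frac{1+t}{\sqrt{m}}\right)^{\mathbf{1}_{2|r}} \frac{\|\Delta f(t)\|_2}{\sqrt{n}},
\end{aligned} \tag{C.19}$$

where we used (C.13) and (C.16). The estimates (C.18) and (C.19) together imply

$$\left|\partial_t\big(K_t^{(r)}(x_{\alpha_1}, \dots, x_{\alpha_r}) - \tilde K_t^{(r)}(x_{\alpha_1}, \dots, x_{\alpha_r})\big)\right| \lesssim \|K_t^{(r+1)} - \tilde K_t^{(r+1)}\|_\infty + \frac{(\ln m)^C}{m^{(r-1)/2}}\left(\frac{1+t}{\sqrt{m}}\right)^{\mathbf{1}_{2|r}} \frac{\|\Delta f(t)\|_2}{\sqrt{n}}.$$

We integrate both sides, and get

$$\|K_t^{(r)} - \tilde K_t^{(r)}\|_\infty \lesssim \int_0^t \|K_s^{(r+1)} - \tilde K_s^{(r+1)}\|_\infty\, ds + \frac{(\ln m)^C}{m^{(r-1)/2}}\left(\frac{1+t}{\sqrt{m}}\right)^{\mathbf{1}_{2|r}} \frac{\int_0^t \|\Delta f(s)\|_2\, ds}{\sqrt{n}},$$

for $t \le \min\{m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$. We notice that in our setting $t$ is much smaller than $\sqrt{m}$. Using $\tilde K_t^{(p)} = 0$ and (C.13) for $r = p-1$, and recursively with $r = p-2, \dots, 2$, we have

$$\|K_t^{(r)} - \tilde K_t^{(r)}\|_\infty \lesssim \frac{(1+t)\,t^{p+1-r}}{m^{p/2}} + \frac{(\ln m)^C}{m^{(r-1)/2}\sqrt{n}}\left(\frac{1+t}{\sqrt{m}}\right)^{\mathbf{1}_{2|r}} \int_0^t \|\Delta f(s)\|_2\, ds.$$

In particular,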

$$\|K_t^{(2)} - \tilde K_t^{(2)}\|_\infty \lesssim \frac{(1+t)\,t^{p-1}}{m^{p/2}} + \frac{(1+t)(\ln m)^C}{m\sqrt{n}}\int_0^t \|\Delta f(s)\|_2\, ds. \tag{C.20}$$

By taking difference of (2.1) and (2.7) we have

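$$\partial_t \Delta f_\alpha(t) = \frac{1}{n}\sum_{\beta=1}^n \big(\tilde K_t^{(2)}(x_\alpha, x_\beta) - K_t^{(2)}(x_\alpha, x_\beta)\big)(\tilde f_\beta(t) - y_\beta) - \frac{1}{n}\sum_{\beta=1}^n K_t^{(2)}(x_\alpha, x_\beta)\,\Delta f_\beta(t). \tag{C.21}$$

We multiply both sides of (C.21) by the vector $\Delta f(t)$:

$$\partial_t \|\Delta f(t)\|_2^2 = \frac{1}{n}\Big\langle \Delta f(t),\ \sum_{\beta=1}^n \big(\tilde K_t^{(2)}(x_\alpha, x_\beta) - K_t^{(2)}(x_\alpha, x_\beta)\big)(\tilde f_\beta(t) - y_\beta) \Big\rangle - \frac{1}{n}\big\langle \Delta f(t),\ K_t^{(2)}\Delta f(t) \big\rangle. \tag{C.22}$$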

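Here we have abused the notation so that $\sum_{\beta=1}^n (\tilde K_t^{(2)}(x_\alpha, x_\beta) - K_t^{(2)}(x_\alpha, x_\beta))(\tilde f_\beta(t) - y_\beta)$ in the above expression is understood as a vector whose $\alpha$-component is $\sum_{\beta=1}^n (\tilde K_t^{(2)}(x_\alpha, x_\beta) - K_t^{(2)}(x_\alpha, x_\beta))(\tilde f_\beta(t) - y_\beta)$. For the first term on the righthand side of (C.22), we estimate it using (C.20),

$$\begin{aligned}
&\left| \frac{1}{n}\Big\langle \Delta f(t),\ \sum_{\beta=1}^n \big(\tilde K_t^{(2)}(x_\alpha, x_\beta) - K_t^{(2)}(x_\alpha, x_\beta)\big)(\tilde f_\beta(t) - y_\beta) \Big\rangle \right| \\
&\quad\le \sum_{\alpha=1}^n |\Delta f_\alpha(t)|\, \|\tilde K_t^{(2)} - K_t^{(2)}\|_\infty\, \frac{1}{n}\sum_{\beta=1}^n |\tilde f_\beta(t) - y_\beta| \\
&\quad\lesssim \left( \frac{(1+t)\,t^{p-1}}{m^{p/2}} + \frac{(1+t)(\ln m)^C}{m\sqrt{n}}\int_0^t \|\Delta f(s)\|_2\, ds \right) \sum_{\alpha=1}^n |\Delta f_\alpha(t)| \\
&\quad\lesssim \left( \frac{(1+t)\,t^{p-1}}{m^{p/2}} + \frac{(1+t)(\ln m)^C}{m\sqrt{n}}\int_0^t \|\Delta f(s)\|_2\, ds \right) \sqrt{n}\,\|\Delta f(t)\|_2.
\end{aligned} \tag{C.23}$$

For the second term on the righthand side of (C.22), we use the fact that $[K_t^{(2)}(x_\alpha, x_\beta)]_{1\le\alpha,\beta\le n}$ is positive definite. In fact, in (C.8), we have proven that for $0 \le t \le c\sqrt{\lambda m/n}/(\ln m)^{C/2}$,

$$\lambda_{\min}\left[K_t^{(2)}(x_\alpha, x_\beta)\right]_{1\le\alpha,\beta\le n} \ge \lambda/2.$$

Therefore,

$$-\frac{1}{n}\big\langle \Delta f(t),\ K_t^{(2)}\Delta f(t) \big\rangle \le -\frac{\lambda}{2n}\|\Delta f(t)\|_2^2. \tag{C.24}$$

By plugging (C.23) and (C.24) into (C.22), and dividing both sides by $2\|\Delta f(t)\|_2$, we get

$$\partial_t \|\Delta f(t)\|_2 \lesssim \sqrt{n}\left( \frac{(1+t)\,t^{p-1}}{m^{p/2}} + \frac{(1+t)(\ln m)^C}{m\sqrt{n}}\int_0^t \|\Delta f(s)\|_2\, ds \right) - \frac{\lambda}{2n}\|\Delta f(t)\|_2, \tag{C.25}$$

for $t \le \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$. To analyze (C.25), we introduce a new quantity,

$$\Delta(t) = \max_{0\le s\le t}\|\Delta f(s)\|_2.$$

Then (C.25) implies

$$\begin{aligned}
\partial_t \Delta(t) &\lesssim \max\left\{ 0,\ \sqrt{n}\left( \frac{(1+t)\,t^{p-1}}{m^{p/2}} + \frac{(1+t)(\ln m)^C}{m\sqrt{n}}\int_0^t \Delta(s)\, ds \right) - \frac{\lambda}{2n}\Delta(t) \right\} \\
&\lesssim \max\left\{ 0,\ \frac{(1+t)\,t^{p-1}\sqrt{n}}{m^{p/2}} + \frac{(1+t)\,t\,(\ln m)^C}{m}\Delta(t) - \frac{\lambda}{2n}\Delta(t) \right\} \\
&\lesssim \max\left\{ 0,\ \frac{(1+t)\,t^{p-1}\sqrt{n}}{m^{p/2}} + \left( \frac{(1+t)\,t\,(\ln m)^C}{m} - \frac{\lambda}{2n} \right)\Delta(t) \right\},
\end{aligned} \tag{C.26}$$

where we used that $\Delta(t)$ is monotonically increasing. We can further simplify the righthand side of (C.26): for $t \le \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$, where $c > 0$ is small enough,

$$\partial_t \Delta(t) \lesssim \max\left\{ 0,\ \frac{(1+t)\,t^{p-1}\sqrt{n}}{m^{p/2}} - \frac{\lambda}{4n}\Delta(t) \right\}. \tag{C.27}$$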

We recall that $\Delta(0) = 0$. We can solve (C.27),

$$\Delta(t) \lesssim e^{-\lambda t/4n}\int_0^t e^{\lambda s/4n}\, \frac{(1+s)\,s^{p-1}\sqrt{n}}{m^{p/2}}\, ds \lesssim \frac{(1+t)\,t^{p-1}\sqrt{n}}{m^{p/2}}\min\{t,\ n/\lambda\}.$$
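The last inequality above uses only that the forcing term is nondecreasing in $s$: bounding it by its value at $s = t$ leaves $\int_0^t e^{-\kappa(t-s)}\,ds = (1 - e^{-\kappa t})/\kappa \le \min\{t, 1/\kappa\}$ with $\kappa = \lambda/(4n)$. The sketch below checks this numerically with a trapezoid rule. The values of `p`, `n`, `lam`, `m` and the helper names are hypothetical, and the factor $4$ in $1/\kappa = 4n/\lambda$ is absorbed into $\lesssim$ in the text.

```python
import math

def delta_solution(t, g, kappa, steps=20000):
    """Trapezoid approximation of exp(-kappa*t) * int_0^t exp(kappa*s) * g(s) ds."""
    if t == 0:
        return 0.0
    h = t / steps
    total = 0.0
    for i in range(steps + 1):
        s = i * h
        weight = 0.5 if i in (0, steps) else 1.0
        # exp(-kappa*t) * exp(kappa*s) is written as exp(-kappa*(t - s)) to avoid overflow.
        total += weight * math.exp(-kappa * (t - s)) * g(s)
    return h * total

# Hypothetical parameter values, for illustration only.
p, n, lam, m = 3, 10, 1.0, 1.0e4
kappa = lam / (4 * n)

def g(s):
    # Forcing term (1 + s) s^(p-1) sqrt(n) / m^(p/2), nondecreasing in s.
    return (1 + s) * s ** (p - 1) * math.sqrt(n) / m ** (p / 2)

checks = []
for t in (1.0, 10.0, 100.0, 1000.0):
    lhs = delta_solution(t, g, kappa)    # solution of the linear comparison ODE
    rhs = g(t) * min(t, 1 / kappa)       # claimed upper bound
    checks.append((t, lhs, rhs))
```

For short times the bound is governed by $t$, and for long times by the relaxation time $1/\kappa$, which is the origin of the factor $\min\{t, n/\lambda\}$.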

It follows that for $t \le \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$,

$$\|\Delta f(t)\|_2 \lesssim \frac{(1+t)\,t^{p-1}\sqrt{n}}{m^{p/2}}\min\{t,\ n/\lambda\}, \tag{C.28}$$

and

$$\|K_t^{(2)} - \tilde K_t^{(2)}\|_\infty \lesssim \frac{(1+t)\,t^{p-1}}{m^{p/2}}\left( 1 + \frac{(1+t)\,t\,(\ln m)^C}{m}\min\left\{t,\ \frac{n}{\lambda}\right\} \right). \tag{C.29}$$

We notice that for $t \le \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'},\ T\}$ and $p \ge 3$, the righthand side of (C.28) is much smaller than $\sqrt{n}$. From the definition of $T$, it is necessary that $T \ge \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}\}$. Thus we can conclude that (C.28) and (C.29) hold for any $t \le \min\{c\sqrt{\lambda m/n}/(\ln m)^{C/2},\ m^{\frac{p}{2(p+1)}}/(\ln m)^{C'}\}$. This finishes the proof of Theorem 2.6.
