FL-NTK: A Neural Tangent Kernel-based Framework for Federated Learning Analysis

Baihe Huang 1 Xiaoxiao Li 2 Zhao Song 3 Xing Yang 4

Abstract

Federated Learning (FL) is an emerging learning scheme that allows different distributed clients to train deep neural networks together without data sharing. Neural networks have become popular due to their unprecedented success. To the best of our knowledge, the theoretical guarantees of FL concerning neural networks with explicit forms and multi-step updates are unexplored. Nevertheless, training analysis of neural networks in FL is non-trivial for two reasons: first, the objective loss function we are optimizing is non-smooth and non-convex, and second, we are not even updating in the gradient direction. Existing convergence results for gradient-descent-based methods heavily rely on the fact that the gradient direction is used for updating. This paper presents a new class of convergence analysis for FL, Federated Learning Neural Tangent Kernel (FL-NTK), which corresponds to overparameterized ReLU neural networks trained by gradient descent in FL and is inspired by the analysis of the Neural Tangent Kernel (NTK). Theoretically, FL-NTK converges to a global-optimal solution at a linear rate with properly tuned learning parameters. Furthermore, with proper distributional assumptions, FL-NTK can also achieve good generalization.

1. Introduction

In traditional centralized training, models learn from the data, and data are collected in a database on the centralized server. In many fields, such as healthcare and natural language processing, models are typically learned from personal data. These personal data are subject to regulations such as the California Consumer Privacy Act (CCPA) (Legislature, 2018), the Health Insurance Portability and Accountability Act (HIPAA) (Act, 1996), and the General Data Protection Regulation (GDPR) of the European Union. Due to the data regulations, standard centralized learning techniques are not appropriate, and users are much less likely to share data. Thus, the data are only available on the local data owners (i.e., edge devices). Federated learning (FL) is a new type of learning scheme that avoids centralizing data in model training. FL allows local data owners (also known as clients) to locally train the private model and then send the model weights or gradients to the central server. The central server then aggregates the shared model parameters to update the new global model, and broadcasts the parameters of the global model to each local client.

Quite different from centralized training, FL has the following unique properties. First, the training data are distributed on an astonishing number of devices, and the connection between the central server and the devices is slow. Thus, the computational cost is a key factor in FL. In communication-efficient FL, local clients are required to update model parameters for a few steps locally and then send their parameters to the server (McMahan et al., 2017). Second, because the data are collected from different clients, the local data points can be sampled from different local distributions. When this happens during training, convergence may not be guaranteed.

The above two unique properties not only bring challenges to algorithm design but also make theoretical analysis much harder. There have been many efforts developing convergence guarantees for FL algorithms based on the assumptions of convexity and smoothness of the objective functions (Yu et al., 2019; Li et al., 2020c; Khaled et al., 2020).
Although a recent study (Li et al., 2021) provides theoretical studies of FL on neural networks, its framework does not cover the analysis of multiple local updates, which is a key feature in FL. One of the most conspicuous questions to ask is:

Can we build a unified and generalizable convergence analysis framework for ReLU neural networks in FL?

In this paper, we give an affirmative answer to this question.

1 Peking University, Beijing, China. 2 The University of British Columbia, Vancouver, BC, Canada. 3 Institute for Advanced Study, Princeton, NJ, United States. 4 The University of Washington, Seattle, WA, United States. Correspondence to: Baihe Huang, Xiaoxiao Li, Zhao Song, Xin Yang.

Proceedings of the 38th International Conference on Machine Learning, PMLR 139, 2021. Copyright 2021 by the author(s).

It was recently proved in a series of papers that gradient descent converges to a global optimum if the neural networks are overparameterized (significantly larger than the size of the datasets). Our work is motivated by the Neural Tangent Kernel (NTK), which was originally proposed by (Jacot et al., 2018) and has been extensively studied over the last few years.

The NTK is defined via the inner products of pairwise data-point gradients (a.k.a. the Gram matrix). It describes the evolution of deep artificial neural networks during their training by gradient descent. Thus, we propose a novel NTK-based framework for federated learning convergence and generalization analysis on ReLU neural networks, the so-called Federated Learning Neural Tangent Kernel (FL-NTK). Unlike the good property of the symmetric Gram matrix in classical NTK, we show that the Gram matrix of FL-NTK is asymmetric in Section 3.2. Our techniques address the core question:

How shall we handle the asymmetric Gram matrix in FL-NTK?

Contributions  Our contributions are two-fold:

• We propose a framework to analyze federated learning in neural networks. By appealing to recent advances in over-parameterized neural networks, we prove convergence and generalization results for federated learning without assumptions on the convexity of the objective functions or on the distribution of the data in the convergence analysis. Thus, we make a first step toward bridging the gap between the empirical success of federated learning and its theoretical understanding in the setting of ReLU neural networks. The results theoretically show that, given fixed training instances, the number of communication rounds increases as the number of clients increases, which is also supported by empirical evidence. We show that when the neural networks are sufficiently wide, the training loss across all clients converges to zero at a linear rate. Furthermore, we also prove a data-dependent generalization bound.

• In federated learning, the update of the global model is no longer determined by the gradient directions directly. Indeed, the heterogeneity of the gradients across multiple local steps hinders the usage of standard neural tangent kernel analysis, which is based on kernel gradient descent in the function space for a positive semi-definite kernel. To address this issue, we identify the dynamics of the training loss by considering all intermediate states of the local steps and establishing the tangent kernel space associated with a general non-symmetric Gram matrix. We prove that this Gram matrix is close to symmetric at initialization using concentration properties, and therefore we can guarantee linear convergence results. This technique may further improve our understanding of many different FL optimization and aggregation methods on neural networks.

Organization  In Section 2 we discuss related work. In Section 3 we formulate the FL convergence problem. In Section 4 we state our results. In Section 5 we summarize our technique overview. In Section 6 we give a proof sketch of our results. In Section 7 we conduct experiments that affirmatively support our theoretical results. In Section 8 we conclude this paper and discuss future work. In Appendix A we list several probability results. In Appendix B we prove our convergence result for FL-NTK. In Appendix C we prove our generalization result for FL-NTK. The full version of this paper is available at https://arxiv.org/abs/2105.05001.

2. Related Work

Federated Learning  Federated learning has emerged as an important paradigm in distributed deep learning. Generally, federated learning can be achieved by two approaches: 1) each party training the model using private data, where only model parameters are transferred; and 2) using encryption techniques to allow safe communication between different parties (Yang et al., 2019a). In this way, the details of the data are not disclosed between parties. In this paper, we focus on the first approach, which has been studied in (Dean et al., 2012; Shokri & Shmatikov, 2015; McMahan et al., 2016; 2017). Federated averaging (FedAvg) (McMahan et al., 2017) first addressed the communication efficiency problem by introducing a global model to aggregate local stochastic gradient descent updates. Later, different variations and adaptations have arisen. This encompasses a myriad of possible approaches, including developing better optimization algorithms (Li et al., 2020a; Wang et al., 2020) and generalizing the model to heterogeneous clients under special assumptions (Zhao et al., 2018; Kairouz et al., 2021; Li et al., 2021).

Federated learning has been widely used in different fields. Healthcare applications have started to utilize FL for multi-center data learning to address small data and privacy concerns in data sharing (Li et al., 2020b; Rieke et al., 2020; Li et al., 2019; Andreux et al., 2020). We have also seen new FL algorithms in mobile edge computing (Wang et al., 2019; Lim et al., 2020; Chen et al., 2020a). FL also has promising applications in autonomous driving (Liang et al., 2019), the financial field (Yang et al., 2019b), and so on.

Convergence of Federated Learning  Despite its promising benefits, FL comes with new challenges to tackle, especially for its convergence analysis under communication-efficient algorithms and data heterogeneity. The convergence of the general FL framework on neural networks is underexplored. A recent work (Li et al., 2021) studies FL convergence on one-layer neural networks. Nevertheless, it is limited by the assumption that each client performs a single local update epoch. Another line of approaches does not directly work in the neural network setting (Li et al., 2020c; Khaled et al., 2020; Yu et al., 2019). Instead, they make assumptions on the convexity and smoothness of the objective functions, which are not realistic for non-linear neural networks.

Convergence of deep neural networks  The convergence of deep neural networks (in the non-FL setting) has been extensively studied over the last couple of years: Gaussian inputs (Zhong et al., 2017b; Li & Yuan, 2017; Zhong et al., 2017a; Ge et al., 2018; Zhong et al., 2019; Bakshi et al., 2019; Chen et al., 2020b), sufficiently wide neural networks (Li & Liang, 2018; Du et al., 2019b; Allen-Zhu et al., 2019a;b; Du et al., 2019a; Zhang et al., 2020; Brand et al., 2021; Song & Zhang, 2021; Song et al., 2021), NTK (Jacot et al., 2018), and beyond NTK (Bai & Lee, 2019; Li et al., 2020d). For more literature, we refer the readers to the textbook (Arora et al., 2020).

3. Problem Formulation as Neural Tangent Kernel

This section is organized as follows:
• We first introduce the notations used in our analysis.
• In Section 3.1, we present the basic setup of NTK and our federated learning algorithm.
• In Section 3.2, we give the analysis of the dynamics of NTK.

To capture the training dynamics of FL on ReLU neural networks, we formulate the problem in the Neural Tangent Kernel regime.

Notations  We use $N$ to denote the number of clients and use $c$ to denote its index. We use $T$ to denote the number of communication rounds, and use $t$ to denote its index. We use $K$ to denote the number of local update steps, and we use $k$ to denote its index. We use $u(t)$ to denote the aggregated server model after round $t$. We use $w_{k,c}(t)$ to denote the $c$-th client's model in round $t$ and step $k$.

Let $S_1 \cup S_2 \cup \cdots \cup S_N = [n]$ and $S_i \cap S_j = \emptyset$ for $i \neq j$. Given $n$ input data points and labels $(x_1, y_1), (x_2, y_2), \cdots, (x_n, y_n)$ in $\mathbb{R}^d \times \mathbb{R}$, the data of each client $c$ is $\{(x_i, y_i) : i \in S_c\}$. Let $\sigma(z) = \max\{z, 0\}$ denote the ReLU activation.

For each client $c \in [N]$, we use $y_c \in \mathbb{R}^{|S_c|}$ to denote the ground truth with regard to its data, and denote $y_c^{(k)}(t) \in \mathbb{R}^{|S_c|}$ to be the (local) model's output on its data in the $t$-th global round and $k$-th local step. For simplicity, we also use $y^{(k)}(t) \in \mathbb{R}^n$ to denote the aggregation of all (local) models' outputs in the $t$-th global round and $k$-th local step.

3.1. Preliminaries

In this subsection we introduce Algorithm 1, a brief description of our federated learning algorithm (under the NTK setting):

• In the $t$-th global round, the server broadcasts $u(t) \in \mathbb{R}^{d \times m}$ to every client.
• Each client $c$ then starts from $w_{0,c}(t) = u(t)$ and takes $K$ (local) steps of gradient descent to find $w_{K,c}(t)$.
• Each client sends $\Delta u_c(t) = w_{K,c}(t) - w_{0,c}(t)$ to the server.
• The server computes a new $u(t+1)$ based on the average of all $\Delta u_c(t)$. Specifically, the server updates $u(t)$ by the average of all local updates $\Delta u_c(t)$ and arrives at
$$u(t+1) = u(t) + \eta_{\mathrm{global}} \cdot \sum_{c \in [N]} \Delta u_c(t) / N.$$
• We repeat the above four steps $T$ times.

Setup  We define a one-hidden-layer neural network function $f : \mathbb{R}^d \to \mathbb{R}$, similarly to (Du et al., 2019b; Song & Yang, 2019; Brand et al., 2021; Song & Zhang, 2021),
$$f(u, x) := \frac{1}{\sqrt{m}} \sum_{r=1}^m a_r \, \sigma(u_r^\top x),$$
where $u \in \mathbb{R}^{d \times m}$ and $u_r \in \mathbb{R}^d$ is the $r$-th column of matrix $u$.

Definition 3.1 (Initialization). We initialize $u \in \mathbb{R}^{d \times m}$ and $a \in \mathbb{R}^m$ as follows:
• For each $r \in [m]$, $u_r$ is sampled from $\mathcal{N}(0, \sigma^2 I)$.
• For each $r \in [m]$, $a_r$ is sampled from $\{-1, +1\}$ uniformly at random (we don't need to train $a$).

We define the loss functions, for $j \in [N]$,
$$L_j(u, x) := \frac{1}{2} \sum_{i \in S_j} (f(u, x_i) - y_i)^2, \qquad L(u, x) := \frac{1}{N} \sum_{j=1}^N L_j(u, x).$$

We can compute the gradient $\frac{\partial f(u, x)}{\partial u_r} \in \mathbb{R}^d$ (of function $f$):
$$\frac{\partial f(u, x)}{\partial u_r} = \frac{1}{\sqrt{m}} \, a_r \, x \, \mathbf{1}_{u_r^\top x \geq 0}.$$
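For concreteness, the following is a minimal NumPy sketch (not the authors' code) of the model $f$, the initialization of Definition 3.1, and the gradient $\partial f / \partial u_r$; the dimensions, data, and random seed are illustrative choices.

```python
import numpy as np

# Minimal sketch of f(u, x) = (1/sqrt(m)) * sum_r a_r * relu(u_r^T x)
# and its gradient in u, following Definition 3.1.
def init_params(d, m, sigma=1.0, seed=0):
    rng = np.random.default_rng(seed)
    u = sigma * rng.standard_normal((d, m))      # columns u_r ~ N(0, sigma^2 I)
    a = rng.choice([-1.0, 1.0], size=m)          # a_r ~ Unif{-1, +1}, kept fixed
    return u, a

def forward(u, a, x):
    m = u.shape[1]
    pre = u.T @ x                                # pre-activations u_r^T x, shape (m,)
    return (a * np.maximum(pre, 0.0)).sum() / np.sqrt(m)

def grad_f(u, a, x):
    # d f / d u_r = (1/sqrt(m)) * a_r * x * 1[u_r^T x >= 0]; returns a (d, m) array
    m = u.shape[1]
    act = (u.T @ x >= 0.0).astype(float)         # activation indicator per neuron
    return (x[:, None] * (a * act)[None, :]) / np.sqrt(m)
```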

Algorithm 1 Training Neural Network with FedAvg under the NTK setting.
1: $u_r(0) \sim \mathcal{N}(0, I_d)$ for $r \in [m]$.  ▷ $u \in \mathbb{R}^{d \times m}$
2: for $t = 1, \ldots, T$ do
3:   for $c = 1, \ldots, N$ do
4:     /* Client receives weights from server */
5:     $w_{0,c}(t) \leftarrow u(t)$  ▷ $w_{0,c}(t), u(t) \in \mathbb{R}^{d \times m}$
6:     /* Client runs $K$ local steps of gradient descent */
7:     for $k = 1, \ldots, K$ do
8:       for $i \in S_c$ do
9:         $y_c^{(k)}(t)_i \leftarrow \frac{1}{\sqrt{m}} \sum_{r=1}^m a_r \sigma(w_{k,c,r}(t)^\top x_i)$  ▷ $y_c^{(k)}(t) \in \mathbb{R}^{|S_c|}$
10:       end for
11:       for $r = 1 \to m$ do
12:         for $i \in S_c$ do
13:           $J_{i,:} \leftarrow \frac{1}{\sqrt{m}} a_r \sigma'(w_{k,c,r}(t)^\top x_i) x_i^\top$  ▷ $J_{i,:} = \frac{\partial f(w_{k,c}(t), x_i, a)}{\partial w_{k,c,r}(t)} \in \mathbb{R}^{1 \times d}$
14:         end for
15:         $\mathrm{grad}_r \leftarrow J^\top (y_c^{(k)}(t) - y_c)$  ▷ $J \in \mathbb{R}^{|S_c| \times d}$
16:         $w_{k,c,r}(t) \leftarrow w_{k-1,c,r}(t) - \eta_{\mathrm{local}} \cdot \mathrm{grad}_r$
17:       end for
18:     end for
19:     /* Client sends weights to server */
20:     $\Delta u_c \leftarrow w_{K,c}(t) - u(t)$
21:   end for
22:   /* Server aggregates the weights */
23:   $\Delta u \leftarrow \frac{1}{N} \sum_{c \in [N]} \Delta u_c$  ▷ $\Delta u \in \mathbb{R}^{d \times m}$
24:   $u(t+1) \leftarrow u(t) + \eta_{\mathrm{global}} \Delta u$  ▷ $u(t+1) \in \mathbb{R}^{d \times m}$
25: end for
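Below is a minimal sketch of the training loop of Algorithm 1, reusing the `init_params`, `forward`, and `grad_f` helpers from the previous sketch; the hyperparameters (`T`, `K`, step sizes) are placeholders chosen for illustration, not the theoretically prescribed values.

```python
import numpy as np

# Sketch of FedAvg under the NTK setting (Algorithm 1), not the authors' implementation.
# client_idx is a list of index arrays S_c partitioning [n].
def fl_ntk_train(X, y, client_idx, d, m, T=100, K=5, eta_local=0.01, eta_global=1.0):
    u, a = init_params(d, m)                         # server model u(0), fixed signs a
    N = len(client_idx)
    for t in range(T):
        delta_sum = np.zeros_like(u)
        for idx in client_idx:                       # each client c with data S_c
            w = u.copy()                             # w_{0,c}(t) <- u(t)
            for _ in range(K):                       # K local gradient-descent steps
                grad = np.zeros_like(w)
                for i in idx:                        # gradient of L_c at current w
                    err = forward(w, a, X[i]) - y[i]
                    grad += err * grad_f(w, a, X[i])
                w -= eta_local * grad
            delta_sum += w - u                       # Delta u_c(t) = w_{K,c}(t) - w_{0,c}(t)
        u = u + eta_global * delta_sum / N           # server averaging step
    return u, a
```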

We can compute the gradient $\frac{\partial L_c(u)}{\partial u_r} \in \mathbb{R}^d$ (of function $L_c$):
$$\frac{\partial L_c(u)}{\partial u_r} = \frac{1}{\sqrt{m}} \sum_{i \in S_c} (f(u, x_i) - y_i) \, a_r \, x_i \, \mathbf{1}_{u_r^\top x_i \geq 0}.$$

We formalize the problem as minimizing the sum of loss functions over all clients:
$$\min_{u \in \mathbb{R}^{d \times m}} L(u).$$

Local update  In each local step, we update $w_{k,c}$ by gradient descent,
$$w_{k+1,c} \leftarrow w_{k,c} - \eta_{\mathrm{local}} \cdot \frac{\partial L_c(w_{k,c})}{\partial w_{k,c}}.$$
Note that
$$\frac{\partial L_c(w_{k,c})}{\partial w_{k,c,r}} = \frac{1}{\sqrt{m}} \sum_{i \in S_c} (f(w_{k,c}, x_i) - y_i) \, a_r \, x_i \, \mathbf{1}_{w_{k,c,r}^\top x_i \geq 0}.$$

Global aggregation  In each global communication round, we aggregate all local updates of the clients by taking a simple average,
$$\Delta u(t) = \sum_{c \in [N]} \Delta u_c(t) / N,$$
where $\Delta u_c(t) = w_{K,c} - w_{0,c}$ for all $c \in [N]$.

Global step  Then the global model simply adds $\Delta u(t)$ to its parameters,
$$u(t+1) \leftarrow u(t) + \eta_{\mathrm{global}} \cdot \Delta u(t).$$

3.2. NTK Analysis

The neural tangent kernel $H^\infty \in \mathbb{R}^{n \times n}$, introduced in (Jacot et al., 2018), is given by
$$H^\infty_{i,j} := \mathbb{E}_{w \sim \mathcal{N}(0, I)}\left[ x_i^\top x_j \, \mathbf{1}_{w^\top x_i \geq 0, \, w^\top x_j \geq 0} \right].$$

At round $t$, let $y(t) = (y_1(t), y_2(t), \cdots, y_n(t)) \in \mathbb{R}^n$ be the prediction vector, where $y_i(t) \in \mathbb{R}$ is defined as $y_i(t) = f(u(t), x_i)$.
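The expectation defining $H^\infty$ can be approximated by sampling $w \sim \mathcal{N}(0, I)$. The following Monte Carlo sketch (with an illustrative sample count) is one way to check an entry numerically; it is not part of the analysis itself.

```python
import numpy as np

# Monte Carlo sketch of H^infty_{i,j} = E_{w ~ N(0,I)}[ x_i^T x_j 1{w^T x_i >= 0} 1{w^T x_j >= 0} ].
def ntk_entry_mc(xi, xj, num_samples=100_000, seed=0):
    rng = np.random.default_rng(seed)
    d = xi.shape[0]
    W = rng.standard_normal((num_samples, d))             # rows w ~ N(0, I_d)
    both_active = ((W @ xi >= 0) & (W @ xj >= 0)).mean()  # P[w^T x_i >= 0, w^T x_j >= 0]
    return float(xi @ xj) * both_active

# For unit vectors, the estimate should be close to the closed form
#   x_i^T x_j * (pi - arccos(x_i^T x_j)) / (2 * pi).
```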

Recall that we denote the labels $y = (y_1, \cdots, y_n) \in \mathbb{R}^n$, $y_c = \{y_i\}_{i \in S_c}$, the predictions $y(t) = (y(t)_1, \cdots, y(t)_n)$, and $y_c(t) = \{y(t)_i\}_{i \in S_c}$. We can then rewrite $\|y - y(t+1)\|_2^2$ as follows:
$$\|y - y(t+1)\|_2^2 = \|y - y(t) - (y(t+1) - y(t))\|_2^2 = \|y - y(t)\|_2^2 - 2 (y - y(t))^\top (y(t+1) - y(t)) + \|y(t+1) - y(t)\|_2^2. \quad (1)$$

Now we focus on $y(t+1) - y(t)$; notice that for each $i \in [n]$,
$$y_i(t+1) - y_i(t) = \frac{1}{\sqrt{m}} \sum_{r=1}^m a_r \Big( \sigma\big( (u_r(t) + \eta_{\mathrm{global}} \Delta u_r(t))^\top x_i \big) - \sigma\big( u_r(t)^\top x_i \big) \Big), \quad (2)$$
where
$$\Delta u_r(t) := \frac{a_r \eta_{\mathrm{local}}}{N \sqrt{m}} \sum_{c \in [N]} \sum_{k \in [K]} \sum_{j \in S_c} (y_j - y_c^{(k)}(t)_j) \, x_j \, \mathbf{1}_{w_{k,c,r}(t)^\top x_j \geq 0}.$$

In order to further analyze Eq. (2), we separate the neurons into two sets. One set contains neurons whose activation pattern changes over time, and the other set contains neurons whose activation pattern stays the same. Specifically, for each $i \in [n]$, we define the set $Q_i \subset [m]$ of neurons whose activation pattern is certified to stay the same throughout the algorithm,
$$Q_i := \{ r \in [m] : \forall w \in \mathbb{R}^d \text{ s.t. } \|w - w_r(0)\|_2 \leq R, \ \mathbf{1}_{w_r(0)^\top x_i \geq 0} = \mathbf{1}_{w^\top x_i \geq 0} \},$$
and use $\overline{Q}_i$ to denote its complement. Then $y_i(t+1) - y_i(t) = v_{1,i} + v_{2,i}$, where
$$v_{1,i} = \frac{1}{\sqrt{m}} \sum_{r \in Q_i} a_r \Big( \sigma\big( (u_r(t) + \eta_{\mathrm{global}} \Delta u_r(t))^\top x_i \big) - \sigma\big( u_r(t)^\top x_i \big) \Big), \quad v_{2,i} = \frac{1}{\sqrt{m}} \sum_{r \in \overline{Q}_i} a_r \Big( \sigma\big( (u_r(t) + \eta_{\mathrm{global}} \Delta u_r(t))^\top x_i \big) - \sigma\big( u_r(t)^\top x_i \big) \Big). \quad (3)$$

The benefit of this procedure is that $v_1$ can be written in closed form,
$$v_{1,i} = -\frac{\eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N m} \sum_{k \in [K], c \in [N]} \sum_{j \in S_c} \sum_{r \in Q_i} q_{k,c,j,r}, \quad \text{where} \quad q_{k,c,j,r} := (y_c^{(k)}(t)_j - y_j) \, x_i^\top x_j \, \mathbf{1}_{w_{k,c,r}(t)^\top x_j \geq 0, \, u_r(t)^\top x_i \geq 0},$$
and $v_2$ is sufficiently small, as we will show later.

Now, we extend the NTK analysis to FL. We start by defining the Gram matrix for $f$ as follows.

Definition 3.2. For any $t \in [0, T]$, $k \in [K]$, $c \in [N]$, we define the matrix $H(t, k, c) \in \mathbb{R}^{n \times n}$ as follows:
$$H(t, k, c)_{i,j} = \frac{1}{m} \sum_{r=1}^m x_i^\top x_j \, \mathbf{1}_{u_r(t)^\top x_i \geq 0, \, w_{k,c,r}(t)^\top x_j \geq 0}, \qquad H(t, k, c)^\perp_{i,j} = \frac{1}{m} \sum_{r \in \overline{Q}_i} x_i^\top x_j \, \mathbf{1}_{u_r(t)^\top x_i \geq 0, \, w_{k,c,r}(t)^\top x_j \geq 0}.$$

This Gram matrix is crucial for the analysis of the error dynamics. When $t = 0$ and the width $m$ approaches infinity, the matrix becomes the NTK, and with infinite width, neural networks just behave like kernel methods with respect to the NTK (Arora et al., 2019b; Lee et al., 2020). It turns out that in the finite-width case (Li & Liang, 2018; Allen-Zhu et al., 2019a;b; Du et al., 2019b;a; Song & Yang, 2019; Oymak & Soltanolkotabi, 2020; Huang & Yau, 2020; Chen & Xu, 2020; Zhang et al., 2020; Brand et al., 2021; Song & Zhang, 2021), the spectral property of the Gram matrix also governs convergence guarantees for neural networks.

We can then decompose $-2(y - y(t))^\top (y(t+1) - y(t))$ into
$$-2 (y - y(t))^\top (y(t+1) - y(t)) = -2 (y - y(t))^\top (v_1 + v_2) = -\frac{2 \eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N} \sum_{i \in [n]} \sum_{k \in [K]} \sum_{c \in [N]} \sum_{j \in S_c} p_{i,k,c,j} - 2 \sum_{i \in [n]} (y_i - y_i(t)) \, v_{2,i}, \quad (4)$$
where
$$p_{i,k,c,j} := (y_i - y_i(t)) \cdot (y_j - y_c^{(k)}(t)_j) \cdot \big( H(t, k, c)_{i,j} - H(t, k, c)^\perp_{i,j} \big).$$

Now it remains to analyze Eq. (1) and Eq. (4).
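A direct finite-width implementation of Definition 3.2 makes the asymmetry visible: the row index is gated by the global weights $u(t)$ while the column index is gated by the local weights $w_{k,c}(t)$. The sketch below is illustrative and not the authors' code.

```python
import numpy as np

# Sketch of the asymmetric Gram matrix of Definition 3.2 for one client c:
#   H(t,k,c)_{i,j} = (1/m) sum_r x_i^T x_j 1{u_r(t)^T x_i >= 0} 1{w_{k,c,r}(t)^T x_j >= 0}.
def gram_matrix(X, u, w_kc):
    # X: (n, d) data matrix; u, w_kc: (d, m) global and local first-layer weights.
    m = u.shape[1]
    act_global = (X @ u >= 0.0).astype(float)     # (n, m): 1{u_r^T x_i >= 0}
    act_local = (X @ w_kc >= 0.0).astype(float)   # (n, m): 1{w_{k,c,r}^T x_j >= 0}
    return (X @ X.T) * (act_global @ act_local.T) / m
```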

Our analysis leverages several key observations from the classical Neural Tangent Kernel theory throughout the learning process:

• Weights change lazily, namely $\|u(t+1) - u(t)\|_2 \leq O(1/n)$.
• Activation patterns remain roughly the same, namely $\|H(t, k, c)^\perp\|_F \leq O(1)$ and $\|v_2\|_2 \leq O(\|y - y(t)\|_2)$.
• Error controls model updates, namely $\|y(t+1) - y(t)\|_2 \leq O(\|y - y(t)\|_2)$.

Based on the above observations, we show that the dynamics of federated learning are dominated in the following way:
$$\|y - y(t+1)\|_2^2 \approx \|y - y(t)\|_2^2 - \frac{2 \eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N} \sum_{k \in [K]} (y - y(t))^\top H(t, k) (y - y^{(k)}(t)),$$
where the Gram matrix $H(t, k) \in \mathbb{R}^{n \times n}$ comes from combining the $S_c$ columns of $H(t, k, c)$ for all $c \in [N]$. Since $S_1 \cup S_2 \cup \cdots \cup S_N = [n]$ and $S_i \cap S_j = \emptyset$, every $j \in [n]$ belongs to one unique $S_c$ for some $c$, and $H(t, k)_{i,j} = H(t, k, c)_{i,j}$ for $j \in S_c$.

There are two difficulties in the analysis of these dynamics. First, unlike the symmetric Gram matrix in the standard NTK theory for centralized training, our FL framework's Gram matrix is asymmetric. Second, the model update in each global round is influenced by all intermediate model states of all the clients.

To address these difficulties, we bring in two new techniques to facilitate our understanding of the learning dynamics.

• First, we generalize Theorem 4.2 in (Du et al., 2019b) to non-symmetric Gram matrices. We show in Lemma B.11 that, with good initialization, $H(t, k)$ is close to the original Gram matrix $H(0)$, so that the model can benefit from a linear learning rate determined by the smallest eigenvalue of $H(0)$.
• Second, we leverage concentration properties at initialization to bound the difference between the errors in the local steps and the errors in the global step. Specifically, we can show that $y - y^{(k)}(t) \approx y - y(t)$ for all $k \in [K]$.

4. Our Results

We first present the main result on the convergence of federated learning in neural networks in the following theorem.

Theorem 4.1 (Informal version of Theorem B.3). Let $m = \Omega(\lambda^{-4} n^4 \log(n/\delta))$, and initialize $u_r(0), a_r$ i.i.d. as in Definition 3.1 with $\sigma = 1$. Let $\lambda$ denote $\lambda_{\min}(H(0))$, and let $\kappa$ denote the condition number of $H(0)$. For $N$ clients, for any $\epsilon$, let
$$T = O\Big( \frac{N}{\eta_{\mathrm{local}} \eta_{\mathrm{global}} K \lambda} \cdot \log(1/\epsilon) \Big).$$
There is an algorithm (FL-NTK) that runs in $T$ global steps, where each client runs $K$ local steps, with the choice
$$\eta_{\mathrm{local}} = O\Big( \frac{\lambda}{K n^2} \Big) \quad \text{and} \quad \eta_{\mathrm{global}} = O(1),$$
and it outputs a weight $U$ such that, with probability at least $1 - \delta$, the training loss $L(u, x)$ satisfies
$$L(u, x) = \frac{1}{2N} \sum_{i=1}^n (f(u, x_i) - y_i)^2 \leq \epsilon.$$

We note that our theoretical framework is very powerful. With additional assumptions on the training data distribution and test data distribution, we can also show an upper bound on the generalization error of federated learning in neural networks. We first introduce a distributional assumption, which is standard in the literature (e.g., see (Arora et al., 2019a; Song & Yang, 2019)).

Definition 4.2 (Non-degenerate Data Distribution, Definition 5.1 in (Arora et al., 2019a)). A distribution $\mathcal{D}$ over $\mathbb{R}^d \times \mathbb{R}$ is $(\lambda, \delta, n)$-non-degenerate if, with probability at least $1 - \delta$, for $n$ i.i.d. samples $\{(x_i, y_i)\}_{i=1}^n$ chosen from $\mathcal{D}$, we have $\lambda_{\min}(H(0)) \geq \lambda > 0$.

Our result on the generalization bound is stated in the following theorem.

Theorem 4.3 (Informal, see Appendix C for details). Fix a failure probability $\delta \in (0, 1)$. Suppose the training data $S = \{(x_i, y_i)\}_{i=1}^n$ are i.i.d. samples from a $(\lambda, \delta/3, n)$-non-degenerate distribution $\mathcal{D}$, and
• $\sigma \leq O(\lambda \cdot \mathrm{poly}(\log n, \log(1/\delta)) / n)$,
• $m \geq \Omega(\lambda^{-2} \cdot \mathrm{poly}(n, \log m, \log(1/\delta), \sigma^{-1}))$,
• $T \geq \Omega((\eta_{\mathrm{local}} \eta_{\mathrm{global}} K \lambda)^{-1} N \log(n/\delta))$.
We initialize $u \in \mathbb{R}^{d \times m}$ and $a \in \mathbb{R}^m$ as in Definition 3.1. Consider any loss function $\ell : \mathbb{R} \times \mathbb{R} \to [0, 1]$ that is 1-Lipschitz in its first argument. Then with probability at least $1 - \delta$ the following event happens: after $T$ global steps, the generalization loss
$$L_{\mathcal{D}}(f) := \mathbb{E}_{(x, y) \sim \mathcal{D}}[\ell(f(u, x), y)]$$
is upper bounded as
$$L_{\mathcal{D}}(f) \leq \big( 2 y^\top (H^\infty)^{-1} y / n \big)^{1/2} + O\big( \sqrt{\log(n / (\lambda \delta)) / n} \big).$$
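The order of $T$ in Theorem 4.1 can be read off from the per-round contraction established in Section 6 (Eq. (5)). The following short calculation is a sketch of that step, with the contraction factor taken from the reconstructed statement above; constants are illustrative.

```latex
% Sketch: from the per-round contraction (Eq. (5)) to the number of global rounds T.
\begin{align*}
\|y - y(T)\|_2^2
  &\le \Big(1 - \tfrac{\eta_{\mathrm{global}}\eta_{\mathrm{local}} K \lambda}{2N}\Big)^{T} \|y - y(0)\|_2^2
   \le \exp\!\Big(-\tfrac{\eta_{\mathrm{global}}\eta_{\mathrm{local}} K \lambda}{2N}\, T\Big) \|y - y(0)\|_2^2,
\intertext{using $1 - z \le e^{-z}$, so the error drops below $\epsilon$ as soon as}
T &\ge \frac{2N}{\eta_{\mathrm{global}}\eta_{\mathrm{local}} K \lambda}\,
       \log\!\Big(\frac{\|y - y(0)\|_2^2}{\epsilon}\Big).
\end{align*}
```

This matches $T = O\big(\tfrac{N}{\eta_{\mathrm{local}} \eta_{\mathrm{global}} K \lambda} \log(1/\epsilon)\big)$ up to the logarithmic dependence on $\|y - y(0)\|_2^2$.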

5. Techniques Overview

In our algorithm, when $K = 1$, i.e., we only perform one local step per global step, we are essentially performing gradient descent on all the data points with step size $\eta_{\mathrm{global}} \eta_{\mathrm{local}} / N$. As the norm of the gradient is proportional to $1/\sqrt{m}$, when the neural network is sufficiently wide, we can control the norm of the gradient. Then, by the update rule of gradient descent, we can upper bound the movement of the first-layer weight $u$. By anti-concentration of the normal distribution, this implies that for each input $x$, the activation status of most of the ReLU gates remains the same as at initialization, which enables us to apply the standard convergence analysis of gradient descent on convex and smooth functions. Finally, we can handle the effect of the ReLU gates whose activation status has changed by carefully choosing the step size.

However, the analysis becomes much more complicated when $K \geq 2$, where the movement of $u$ is no longer determined by the gradient directly. Nevertheless, on each client, we are still performing gradient descent for $K$ local steps. So we can handle the movement of the local weights $w$. The major technical bulk of this work is then proving that the training error shrinks when we set the global weight movement to be the average of the local weight movements. Our argument is inspired by that of (Du et al., 2019b) but is much more involved.

6. Proof Sketch

In this section we sketch our proof of Theorem 4.1 and Theorem 4.3. The detailed proof is deferred to Appendix B.

In order to prove the linear convergence rate in Theorem 4.1, it is sufficient to show that the training loss shrinks in each round, or formally, for each $\tau = 0, 1, \cdots$,
$$\|y(\tau+1) - y\|_2^2 \leq \Big( 1 - \frac{\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda}{2N} \Big) \cdot \|y(\tau) - y\|_2^2. \quad (5)$$

We prove Eq. (5) by induction. Assume that we have proved it for $\tau \leq t - 1$, and we want to prove Eq. (5) for $\tau = t$. We first show that the movement of the weight $u$ is bounded under the induction hypothesis.

Lemma 6.1 (Movement of global weight, informal version of Lemma B.12). For any $r \in [m]$,
$$\|u_r(t) - u_r(0)\|_2 = O\Big( \frac{\sqrt{n} \, \|y - y(0)\|_2}{\sqrt{m} \, \lambda} \Big).$$

The detailed proof can be found in Appendix B.6.

We then turn to the proof of Eq. (5) by decomposing the loss in the $(t+1)$-th global round,
$$\|y - y(t+1)\|_2^2 = \|y - y(t)\|_2^2 - 2 (y - y(t))^\top (y(t+1) - y(t)) + \|y(t+1) - y(t)\|_2^2.$$

To this end, we need to investigate the change of the prediction in consecutive rounds, which is described in Eq. (2). For the sake of simplicity, we introduce the notation
$$z_{i,r} := \sigma\big( (u_r(t) + \eta_{\mathrm{global}} \Delta u_r(t))^\top x_i \big) - \sigma\big( u_r(t)^\top x_i \big);$$
then we have
$$y(t+1)_i - y(t)_i = \frac{1}{\sqrt{m}} \sum_{r=1}^m a_r z_{i,r} = \frac{1}{\sqrt{m}} \sum_{r \in Q_i} a_r z_{i,r} + v_{2,i},$$
where $v_{2,i}$ is introduced in Eq. (3).

For client $c \in [N]$, let $y_c^{(k)}(t)_j$ ($j \in S_c$) be defined by $y_c^{(k)}(t)_j = f(w_{k,c}(t), x_j)$. By the gradient-averaging scheme described in Algorithm 1, $\Delta u_r(t)$, the change in the global weights, is
$$\Delta u_r(t) = \frac{a_r \eta_{\mathrm{local}}}{N \sqrt{m}} \sum_{c \in [N]} \sum_{k \in [K]} \sum_{j \in S_c} (y_j - y_c^{(k)}(t)_j) \, x_j \, \mathbf{1}_{w_{k,c,r}(t)^\top x_j \geq 0}.$$

Therefore, we can calculate $\frac{1}{\sqrt{m}} \sum_{r \in Q_i} a_r z_{i,r}$:
$$\frac{1}{\sqrt{m}} \sum_{r \in Q_i} a_r z_{i,r} = \frac{1}{\sqrt{m}} \sum_{r \in Q_i} a_r \Big( \sigma\big( (u_r(t) + \eta_{\mathrm{global}} \Delta u_r(t))^\top x_i \big) - \sigma\big( u_r(t)^\top x_i \big) \Big) = \frac{\eta_{\mathrm{global}} \eta_{\mathrm{local}}}{m N} \sum_{k \in [K]} \sum_{c \in [N]} \sum_{j \in S_c} \sum_{r \in Q_i} (y_j - y_c^{(k)}(t)_j) \, x_i^\top x_j \, \mathbf{1}_{u_r(t)^\top x_i \geq 0, \, w_{k,c,r}(t)^\top x_j \geq 0} = \frac{\eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N} \sum_{k \in [K]} \sum_{c \in [N]} \sum_{j \in S_c} (y_j - y_c^{(k)}(t)_j) \big( H(t, k, c)_{i,j} - H(t, k, c)^\perp_{i,j} \big).$$

Further, we write $-2(y - y(t))^\top (y(t+1) - y(t))$ as follows:
$$-2 (y - y(t))^\top (y(t+1) - y(t)) = -2 (y - y(t))^\top (v_1 + v_2) = -\frac{2 \eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N} \sum_{i \in [n]} \sum_{k \in [K]} \sum_{c \in [N]} \sum_{j \in S_c} (y_i - y_i(t)) (y_j - y_c^{(k)}(t)_j) \big( H(t, k, c)_{i,j} - H(t, k, c)^\perp_{i,j} \big) - 2 \sum_{i \in [n]} (y_i - y_i(t)) \, v_{2,i}.$$

To summarize, we can decompose the loss as
$$\|y - y(t+1)\|_2^2 = \|y - y(t)\|_2^2 + C_1 + C_2 + C_3 + C_4,$$
where
$$C_1 = -\frac{2 \eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N} \sum_{i \in [n]} \sum_{k \in [K]} \sum_{c \in [N]} \sum_{j \in S_c} (y_i - y_i(t)) (y_j - y_c^{(k)}(t)_j) \, H(t, k, c)_{i,j},$$
$$C_2 = +\frac{2 \eta_{\mathrm{global}} \eta_{\mathrm{local}}}{N} \sum_{i \in [n]} \sum_{k \in [K]} \sum_{c \in [N]} \sum_{j \in S_c} (y_i - y_i(t)) (y_j - y_c^{(k)}(t)_j) \, H(t, k, c)^\perp_{i,j},$$
$$C_3 = -2 \sum_{i \in [n]} (y_i - y_i(t)) \, v_{2,i}, \qquad C_4 = +\|y(t+1) - y(t)\|_2^2.$$

Let $R = \max_{r \in [m]} \|u_r(t) - u_r(0)\|_2$ be the maximal movement of the global weights. Note that by Lemma 6.1, $R$ can be made arbitrarily small as long as the width $m$ is sufficiently large. Next, we bound $C_1, C_2, C_3, C_4$ by arguing that they are bounded from above if $R$ is small and the learning rates are properly chosen. The detailed proof is deferred to Appendix B.3.

Lemma 6.2 (Bounding of $C_1, C_2, C_3, C_4$, informal version of Claim B.4, Claim B.5, Claim B.6 and Claim B.7). Assume that
• $R = O(\lambda / n)$,
• $\eta_{\mathrm{local}} = O(1/(K n^2))$,
• $\eta_{\mathrm{global}} = O(1)$.
Then with high probability we have
$$C_1 \leq -\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda \|y - y(t)\|_2^2 / N, \qquad C_2 \leq +\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda \|y - y(t)\|_2^2 / (40 N),$$
$$C_3 \leq +\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda \|y - y(t)\|_2^2 / (40 N), \qquad C_4 \leq +\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda \|y - y(t)\|_2^2 / (40 N).$$

Combining the above lemma, we arrive at
$$\|y - y(t+1)\|_2^2 \leq \Big( 1 - \frac{\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda}{2N} \Big) \cdot \|y - y(t)\|_2^2,$$
which completes the proof of Eq. (5). Finally, Theorem 4.1 follows from
$$\|y - y(T)\|_2^2 \leq \Big( 1 - \frac{\eta_{\mathrm{global}} \eta_{\mathrm{local}} K \lambda}{2N} \Big)^T \|y - y(0)\|_2^2 \leq \epsilon.$$

7. Experiment

Models and datasets  We examine our theoretical results on a benchmark dataset in FL studies, Cifar10. We perform the 10-class classification task using ResNet56 (He et al., 2016). For a fair convergence comparison, we fix the total number of samples $n$. Based on our main result, Theorem 4.1, we show the convergence with respect to the number of clients $N$. To clearly evaluate the effect of $N$, we set all the clients to be activated (sampling all the clients) at each aggregation. We examine the settings of both non-iid and iid clients:

iid  Data distribution is homogeneous across all the clients. Specifically, the label distribution over the 10 classes is a uniform distribution.

non-iid  Data distribution is heterogeneous across all the clients. For non-iid splits, we utilize the Dirichlet distribution as in (Yurochkin et al., 2019; Wang et al., 2020; He et al., 2020a). First, each of these datasets is heterogeneously divided into $J$ batches. The heterogeneous partition allocates a $p_z \sim \mathrm{Dir}_J(\alpha)$ proportion of the instances of label $z$ to batch $j$. Then one label is sampled based on these vectors for each device, and an image is sampled without replacement based on the label. For all the classes, we repeat this process until all data points are assigned to devices.

Setup  The experiment is conducted on one Nvidia 2080 Ti GPU. We use SGD as the local optimizer with a learning rate of 0.03 to train the neural network.1 We set the batch size to 128. We set the local update epochs to $K = 1, 2, 5$ and $10$.2 We set the Dirichlet distribution parameter for the non-iid data to $\alpha = 0.5$. For all local update epoch settings, we set the number of client-server communication rounds to 200. There are 50,000 training instances in Cifar10. We vary the number of clients $N = 10, 50, 100$ and record the training loss over the communication rounds. The implementation is based on FedML (He et al., 2020b).

Impact of N  Our theory suggests that, given fixed training instances (fixed $n$), a smaller $N$ requires fewer communication rounds to converge to a given $\epsilon$. In other words, a smaller $N$ may accelerate the convergence of the global model. Intuitively, a larger $N$ on a fixed amount of data means a heavier communication burden. Figure 1 shows the training loss curves for different choices of $N$. We empirically observe consistent results over different $K$: a smaller $N$ converges faster given a fixed number of total training samples, which supports our theoretical results.

1 It is difficult to run a real NTK experiment or full-batch gradient descent (GD) due to memory constraints.
2 The local update step equals the epoch in GD, but the number of steps of one epoch in SGD is equal to the number of mini-batches.
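The Dirichlet-based non-iid split described above can be implemented as follows; this is an illustrative sketch in the spirit of (Yurochkin et al., 2019; Wang et al., 2020), not the exact partitioning script used for Figure 1.

```python
import numpy as np

# Sketch of a Dirichlet non-iid label partition: for each class z, draw class
# proportions p_z ~ Dir_J(alpha) over the clients and split that class accordingly.
def dirichlet_partition(labels, num_clients, alpha=0.5, seed=0):
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(num_clients)]
    for z in np.unique(labels):
        idx_z = rng.permutation(np.where(labels == z)[0])
        p = rng.dirichlet(alpha * np.ones(num_clients))   # p_z ~ Dir_J(alpha)
        cuts = (np.cumsum(p)[:-1] * len(idx_z)).astype(int)
        for c, part in enumerate(np.split(idx_z, cuts)):  # give client c its share of class z
            client_idx[c].extend(part.tolist())
    return [np.array(sorted(idx)) for idx in client_idx]
```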

Figure 1: Training loss vs. communication rounds for the number of clients N = 10, 50, 100, under the iid and non-iid settings, using the mini-batch SGD optimizer. Panels: (a) local update epoch K = 1; (b) K = 2; (c) K = 5; (d) K = 10.

8. Discussion

Overall, we provide the first comprehensive proof of the convergence of gradient descent and a generalization bound for over-parameterized ReLU neural networks in federated learning. We consider the training dynamics with respect to the number of local update steps $K$ and the number of clients $N$.

Different from most existing theoretical analyses of FL, FL-NTK is able to exploit the neural network parameters. There is great potential for FL-NTK to help understand the behavior of different neural network architectures in FL, e.g., how batch normalization affects convergence, as in FedBN (Li et al., 2021). Different from FedBN, whose analysis is limited to $K = 1$, we provide the first general framework for FL convergence analysis by considering $K \geq 2$. The extension is non-trivial, as the parameter update does not follow the gradient direction due to the heterogeneity of the local data. To tackle this issue, we establish the training trajectory by considering all intermediate states and establishing an asymmetric Gram matrix related to the local gradient aggregations. We show that with quartic network width, federated learning can converge to a global optimum at a linear rate. We also provide a data-dependent generalization bound for over-parameterized neural networks trained with federated learning.

It will be interesting to extend the current FL-NTK framework to the multi-layer case. Existing NTK results for two-layer neural networks (NNs) have been shown to generalize to multi-layer NNs with various structures such as RNNs, CNNs, ResNets, etc. (Allen-Zhu et al., 2019a;b; Du et al., 2019a). The key techniques for analyzing FL on wide neural networks are addressed in our work. Our result can be generalized to multi-layer NNs by controlling the perturbation propagation through layers using well-developed tools (rank-one perturbation analysis, randomness decomposition, and the extended McDiarmid's inequality). We hope our results and techniques will provide insights for further study of distributed learning and other optimization algorithms.

Acknowledgements

Baihe Huang is supported by the Elite Undergraduate Training Program of the School of Mathematical Sciences at Peking University; this work was done while interning at Princeton University and the Institute for Advanced Study (advised by Zhao Song). Xin Yang is supported by NSF grant CCF-2006359. This project was partially done while Xin was a Ph.D. student at the University of Washington. The authors would like to thank the ICML reviewers for their valuable comments.

The full version of this paper is available at https://arxiv.org/abs/2105.05001.

References

Act, A. Health insurance portability and accountability act of 1996. Public law, 104:191, 1996.

Allen-Zhu, Z., Li, Y., and Song, Z. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning (ICML), pp. 242–252, 2019a.

Allen-Zhu, Z., Li, Y., and Song, Z. On the convergence rate of training recurrent neural networks. In Advances in Neural Information Processing Systems (NeurIPS), 2019b.

Andreux, M., du Terrail, J. O., Beguier, C., and Tramel, E. W. Siloed federated learning for multi-centric histopathology datasets. In Domain Adaptation and Representation Transfer, and Distributed and Collaborative Learning, pp. 129–139. Springer, 2020.

Arora, R., Arora, S., Bruna, J., Cohen, N., Du, S., Ge, R., Gunasekar, S., Jin, C., Lee, J., Ma, T., Neyshabur, B., and Song, Z. Theory of deep learning. In manuscript, 2020.

Arora, S., Du, S., Hu, W., Li, Z., and Wang, R. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning (ICML), pp. 322–332, 2019a.

Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R. R., and Wang, R. On exact computation with an infinitely wide neural net. In Advances in Neural Information Processing Systems (NeurIPS), pp. 8141–8150, 2019b.

Bai, Y. and Lee, J. D. Beyond linearization: On quadratic and higher-order approximation of wide neural networks. In International Conference on Learning Representations (ICLR), 2019.

Bakshi, A., Jayaram, R., and Woodruff, D. P. Learning two layer rectified neural networks in polynomial time. In Conference on Learning Theory (COLT), pp. 195–268. PMLR, 2019.

Bernstein, S. On a modification of chebyshev's inequality and of the error formula of laplace. Ann. Sci. Inst. Sav. Ukraine, Sect. Math, 1(4):38–49, 1924.

Brand, J. v. d., Peng, B., Song, Z., and Weinstein, O. Training (overparametrized) neural networks in near-linear time. In Innovations in Theoretical Computer Science (ITCS), 2021.

Chen, L. and Xu, S. Deep neural tangent kernel and laplace kernel have the same rkhs. arXiv preprint arXiv:2009.10683, 2020.

Chen, M., Yang, Z., Saad, W., Yin, C., Poor, H. V., and Cui, S. A joint learning and communications framework for federated learning over wireless networks. IEEE Transactions on Wireless Communications, 2020a.

Chen, S., Klivans, A. R., and Meka, R. Learning deep relu networks is fixed-parameter tractable. arXiv preprint arXiv:2009.13512, 2020b.

Dean, J., Corrado, G., Monga, R., Chen, K., Devin, M., Mao, M., Ranzato, M., Senior, A., Tucker, P., Yang, K., et al. Large scale distributed deep networks. In Advances in Neural Information Processing Systems (NeurIPS), pp. 1223–1231, 2012.

Du, S., Lee, J., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning (ICML), pp. 1675–1685. PMLR, 2019a.

Du, S. S., Zhai, X., Poczos, B., and Singh, A. Gradient descent provably optimizes over-parameterized neural networks. In International Conference on Learning Representations (ICLR). https://arxiv.org/pdf/1810.02054, 2019b.

Ge, R., Lee, J. D., and Ma, T. Learning one-hidden-layer neural networks with landscape design. In International Conference on Learning Representations (ICLR), 2018.

He, C., Annavaram, M., and Avestimehr, S. Group knowledge transfer: Federated learning of large cnns at the edge. Advances in Neural Information Processing Systems (NeurIPS), 2020a.

He, C., Li, S., So, J., Zhang, M., Wang, H., Wang, X., Vepakomma, P., Singh, A., Qiu, H., Shen, L., et al. Fedml: A research library and benchmark for federated machine learning. NeurIPS Federated Learning Workshop, 2020b.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778, 2016.

Huang, J. and Yau, H.-T. Dynamics of deep neural networks and neural tangent hierarchy. In International Conference on Machine Learning (ICML), pp. 4542–4551, 2020.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems (NeurIPS), pp. 8571–8580, 2018.

Kairouz, P., McMahan, H. B., Avent, B., Bellet, A., Bennis, M., Bhagoji, A. N., Bonawitz, K., Charles, Z., Cormode, G., Cummings, R., et al. Advances and open problems in federated learning. Foundations and Trends in Machine Learning, 2021.

Khaled, A., Mishchenko, K., and Richtárik, P. Tighter theory for local SGD on identical and heterogeneous data. In Proceedings of Artificial Intelligence and Statistics (AISTATS), 2020.

Lee, J. D., Shen, R., Song, Z., Wang, M., and Yu, Z. Generalized leverage score sampling for neural networks. In Advances in Neural Information Processing Systems (NeurIPS), 2020.

Legislature, C. S. California consumer privacy act (ccpa). https://oag.ca.gov/privacy/ccpa, 2018.

Li, T., Sahu, A. K., Zaheer, M., Sanjabi, M., Talwalkar, A., and Smith, V. Federated optimization in heterogeneous networks. In Conference on Machine Learning and Systems (MLSys), 2020a.

Li, W., Milletarì, F., Xu, D., Rieke, N., Hancox, J., Zhu, W., Baust, M., Cheng, Y., Ourselin, S., Cardoso, M. J., et al. Privacy-preserving federated brain tumour segmentation. In International Workshop on Machine Learning in Medical Imaging, pp. 133–141. Springer, 2019.

Li, X., Gu, Y., Dvornek, N., Staib, L., Ventola, P., and Duncan, J. S. Multi-site fmri analysis using privacy-preserving federated learning and domain adaptation: Abide results. Medical Image Analysis, 2020b.

Li, X., Huang, K., Yang, W., Wang, S., and Zhang, Z. On the convergence of fedavg on non-iid data. International Conference on Learning Representations (ICLR), 2020c.

Li, X., Jiang, M., Zhang, X., Kamp, M., and Dou, Q. Fedbn: Federated learning on non-iid features via local batch normalization. In International Conference on Learning Representations (ICLR), 2021. URL https://openreview.net/forum?id=6YEQUn0QICG.

Li, Y. and Liang, Y. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems (NeurIPS), pp. 8157–8166, 2018.

Li, Y. and Yuan, Y. Convergence analysis of two-layer neural networks with ReLU activation. In Advances in Neural Information Processing Systems (NIPS), pp. 597–607, 2017.

Li, Y., Ma, T., and Zhang, H. R. Learning over-parametrized two-layer neural networks beyond ntk. In Conference on Learning Theory (COLT), pp. 2613–2682, 2020d.

Liang, X., Liu, Y., Chen, T., Liu, M., and Yang, Q. Federated transfer reinforcement learning for autonomous driving. arXiv preprint arXiv:1910.06001, 2019.

Lim, W. Y. B., Luong, N. C., Hoang, D. T., Jiao, Y., Liang, Y.-C., Yang, Q., Niyato, D., and Miao, C. Federated learning in mobile edge networks: A comprehensive survey. IEEE Communications Surveys & Tutorials, 22(3):2031–2063, 2020.

McMahan, B., Moore, E., Ramage, D., Hampson, S., and y Arcas, B. A. Communication-efficient learning of deep networks from decentralized data. In Proceedings of Artificial Intelligence and Statistics (AISTATS), pp. 1273–1282. PMLR, 2017.

McMahan, H., Moore, E., Ramage, D., and Agüera y Arcas, B. Federated learning of deep networks using model averaging. 2016.

Oymak, S. and Soltanolkotabi, M. Towards moderate overparameterization: global convergence guarantees for training shallow neural networks. arXiv preprint, https://arxiv.org/pdf/1902.04674.pdf, 2020.

Rieke, N., Hancox, J., Li, W., Milletari, F., Roth, H. R., Albarqouni, S., Bakas, S., Galtier, M. N., Landman, B. A., Maier-Hein, K., et al. The future of digital health with federated learning. NPJ Digital Medicine, 3(1):1–7, 2020.

Shokri, R. and Shmatikov, V. Privacy-preserving deep learning. In Proceedings of the 22nd ACM SIGSAC Conference on Computer and Communications Security (CCS), pp. 1310–1321. ACM, 2015.

Song, Z. and Yang, X. Quadratic suffices for over-parametrization via matrix chernoff bound. arXiv preprint arXiv:1906.03593, 2019.

Song, Z. and Zhang, R. Hybrid quantum-classical implementation of training over-parameterized neural networks. In manuscript, 2021.

Song, Z., Yang, S., and Zhang, R. Does preprocessing help training over-parameterized neural networks. In manuscript, 2021.

Wang, H., Yurochkin, M., Sun, Y., Papailiopoulos, D., and Khazaeni, Y. Federated learning with matched averaging. arXiv preprint arXiv:2002.06440, 2020.

Wang, S., Tuor, T., Salonidis, T., Leung, K. K., Makaya, C., He, T., and Chan, K. Adaptive federated learning in resource constrained edge computing systems. IEEE Journal on Selected Areas in Communications, 37(6):1205–1221, 2019.

Yang, Q., Liu, Y., Chen, T., and Tong, Y. Federated machine learning: Concept and applications. ACM Transactions on Intelligent Systems and Technology (TIST), 10(2):12, 2019a.

Yang, W., Zhang, Y., Ye, K., Li, L., and Xu, C.-Z. Ffd: a federated learning based method for credit card fraud detection. In International Conference on Big Data, pp. 18–32. Springer, 2019b.

Yu, H., Yang, S., and Zhu, S. Parallel restarted SGD with faster convergence and less communication: Demystifying why model averaging works for deep learning. In Proceedings of the AAAI Conference on Artificial Intelligence (AAAI), pp. 5693–5700, 2019.

Yurochkin, M., Agarwal, M., Ghosh, S., Greenewald, K., Hoang, N., and Khazaeni, Y. Bayesian nonparametric federated learning of neural networks. In International Conference on Machine Learning (ICML), pp. 7252–7261, 2019.

Zhang, Y., Plevrakis, O., Du, S. S., Li, X., Song, Z., and Arora, S. Over-parameterized adversarial training: An analysis overcoming the curse of dimensionality. In Advances in Neural Information Processing Systems (NeurIPS), 2020.

Zhao, Y., Li, M., Lai, L., Suda, N., Civin, D., and Chandra, V. Federated learning with non-iid data. arXiv preprint arXiv:1806.00582, 2018.

Zhong, K., Song, Z., and Dhillon, I. S. Learning non-overlapping convolutional neural networks with multiple kernels. arXiv preprint arXiv:1711.03440, 2017a.

Zhong, K., Song, Z., Jain, P., Bartlett, P. L., and Dhillon, I. S. Recovery guarantees for one-hidden-layer neural networks. In International Conference on Machine Learning (ICML), 2017b.

Zhong, K., Song, Z., Jain, P., and Dhillon, I. S. Provable non-linear inductive matrix completion. Advances in Neural Information Processing Systems (NeurIPS), 32:11439–11449, 2019.