FetchSGD: Communication-Efficient Federated Learning with Sketching
Daniel Rothchild*¹  Ashwinee Panda*¹  Enayat Ullah²  Nikita Ivkin³  Ion Stoica¹  Vladimir Braverman²  Joseph Gonzalez¹  Raman Arora²

*Equal contribution. ¹University of California, Berkeley, California, USA. ²Johns Hopkins University, Baltimore, Maryland. ³Amazon. Correspondence to: Daniel Rothchild <[email protected]>.

Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).

Abstract

Existing approaches to federated learning suffer from a communication bottleneck as well as convergence issues due to sparse client participation. In this paper we introduce a novel algorithm, called FetchSGD, to overcome these challenges. FetchSGD compresses model updates using a Count Sketch, and then takes advantage of the mergeability of sketches to combine model updates from many workers. A key insight in the design of FetchSGD is that, because the Count Sketch is linear, momentum and error accumulation can both be carried out within the sketch. This allows the algorithm to move momentum and error accumulation from clients to the central aggregator, overcoming the challenges of sparse client participation while still achieving high compression rates and good convergence. We prove that FetchSGD has favorable convergence guarantees, and we demonstrate its empirical effectiveness by training two residual networks and a transformer model.

1. Introduction

Federated learning has recently emerged as an important setting for training machine learning models. In the federated setting, training data is distributed across a large number of edge devices, such as consumer smartphones, personal computers, or smart home devices. These devices have data that is useful for training a variety of models – for text prediction, speech modeling, facial recognition, document identification, and other tasks (Shi et al., 2016; Brisimi et al., 2018; Leroy et al., 2019; Tomlinson et al., 2009). However, data privacy, liability, or regulatory concerns may make it difficult to move this data to the cloud for training (EU, 2018). Even without these concerns, training machine learning models in the cloud can be expensive, and an effective way to train the same models on the edge has the potential to eliminate this expense.

When training machine learning models in the federated setting, participating clients do not send their local data to a central server; instead, a central aggregator coordinates an optimization procedure among the clients. At each iteration of this procedure, clients compute gradient-based updates to the current model using their local data, and they communicate only these updates to a central aggregator.

A number of challenges arise when training models in the federated setting. Active areas of research in federated learning include solving systems challenges, such as handling stragglers and unreliable network connections (Bonawitz et al., 2016; Wang et al., 2019), tolerating adversaries (Bagdasaryan et al., 2018; Bhagoji et al., 2018), and ensuring privacy of user data (Geyer et al., 2017; Hardy et al., 2017). In this work we address a different challenge, namely that of training high-quality models under the constraints imposed by the federated setting.

There are three main constraints unique to the federated setting that make training high-quality models difficult. First, communication efficiency is a necessity when training on the edge (Li et al., 2018), since clients typically connect to the central aggregator over slow connections (~1 Mbps) (Lee et al., 2010). Second, clients must be stateless, since it is often the case that no client participates more than once during all of training (Kairouz et al., 2019). Third, the data collected across clients is typically not independent and identically distributed. For example, when training a next-word prediction model on the typing data of smartphone users, clients located in geographically distinct regions generate data from different distributions, but enough commonality exists between the distributions that we may still want to train a single model (Hard et al., 2018; Yang et al., 2018).

In this paper, we propose a new optimization algorithm for federated learning, called FetchSGD, that can train high-quality models under all three of these constraints. The crux of the algorithm is simple: at each round, clients compute a gradient based on their local data, then compress the gradient using a data structure called a Count Sketch before sending it to the central aggregator. The aggregator maintains momentum and error accumulation Count Sketches, and the weight update applied at each round is extracted from the error accumulation sketch. See Figure 1 for an overview of FetchSGD.
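To make the data flow concrete, the following is a minimal Python sketch of one such round under the description above. The CountSketch class, its parameters (rows, cols, shared seed), and the top-k extraction step are illustrative assumptions for exposition, not the authors' implementation; all parties are assumed to share the same hash functions via a common seed.

```python
import numpy as np

class CountSketch:
    """Minimal Count Sketch of a d-dimensional vector (illustrative, not the reference code)."""
    def __init__(self, d, rows=5, cols=1000, seed=0):
        rng = np.random.default_rng(seed)                    # shared seed: clients and server agree on hashes
        self.d, self.rows, self.cols = d, rows, cols
        self.bucket = rng.integers(0, cols, size=(rows, d))  # bucket hash for each coordinate, per row
        self.sign = rng.choice([-1.0, 1.0], size=(rows, d))  # sign hash for each coordinate, per row
        self.table = np.zeros((rows, cols))

    def accumulate(self, vec):
        """Add a dense vector into the sketch; sketching is linear in vec."""
        for r in range(self.rows):
            np.add.at(self.table[r], self.bucket[r], self.sign[r] * vec)

    def estimate(self):
        """Median-of-rows estimate of every coordinate of the sketched vector."""
        per_row = np.stack([self.sign[r] * self.table[r, self.bucket[r]] for r in range(self.rows)])
        return np.median(per_row, axis=0)


def fetchsgd_round(weights, client_grads, momentum_sk, error_sk, lr=0.1, rho=0.9, k=10):
    """One aggregation round in the style described above (a hedged sketch, not the authors' code)."""
    d = weights.size
    agg = CountSketch(d)                       # same defaults as momentum_sk / error_sk, so tables are compatible
    for g in client_grads:                     # in practice each client sketches locally and uploads only its table
        agg.accumulate(g)
    agg.table /= len(client_grads)

    momentum_sk.table = rho * momentum_sk.table + agg.table  # momentum carried out inside the sketch (linearity)
    error_sk.table += lr * momentum_sk.table                 # error accumulation also lives in the sketch

    est = error_sk.estimate()
    top = np.argsort(np.abs(est))[-k:]         # take the k heaviest coordinates as the weight update
    delta = np.zeros(d)
    delta[top] = est[top]

    applied = CountSketch(d)
    applied.accumulate(delta)
    error_sk.table -= applied.table            # remove what was applied; the rest remains as accumulated error
    return weights - delta
```

Here momentum_sk and error_sk are CountSketch(d) instances that the aggregator creates once, with the same rows, cols, and seed, and reuses across rounds; no state of any kind is kept on the clients.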
FetchSGD requires no local state on the clients, and we prove that it is communication efficient, and that it converges in the non-i.i.d. setting for L-smooth non-convex functions at rates O(T^{-1/2}) and O(T^{-1/3}), respectively, under two alternative assumptions – the first opaque and the second more intuitive. Furthermore, even without maintaining any local state, FetchSGD can carry out momentum – a technique that is essential for attaining high accuracy in the non-federated setting – as if on local gradients before compression (Sutskever et al., 2013). Lastly, due to properties of the Count Sketch, FetchSGD scales seamlessly to small local datasets, an important regime for federated learning, since user interaction with online services tends to follow a power law distribution, meaning that most users will have relatively little data to contribute (Muchnik et al., 2013).

We empirically validate our method with two image recognition tasks and one language modeling task. Using models with between 6 and 125 million parameters, we train on non-i.i.d. datasets that range in size from 50,000 to 800,000 examples.

2. Related Work

Broadly speaking, there are two optimization strategies that have been proposed to address the constraints of federated learning: Federated Averaging (FedAvg) and extensions thereof, and gradient compression methods. We explore these two strategies in detail in Sections 2.1 and 2.2, but as a brief summary, FedAvg does not require local state, but it also does not reduce communication from the standpoint of a client that participates once, and it struggles with non-i.i.d. data and small local datasets because it takes many local gradient steps. Gradient compression methods, on the other hand, can achieve high communication efficiency. However, it has been shown both theoretically and empirically that these methods must maintain error accumulation vectors on the clients in order to achieve high accuracy. This is ineffective in federated learning, since clients typically participate in optimization only once, so the accumulated error has no chance to be re-introduced (Karimireddy et al., 2019b).

2.1. FedAvg

FedAvg reduces the total number of bytes transferred during training by carrying out multiple steps of stochastic gradient descent (SGD) locally before sending the aggregate model update. This strategy, often referred to as local/parallel SGD, has been studied since the early days of distributed model training in the data center (Dean et al., 2012), and is referred to as FedAvg when applied to federated learning (McMahan et al., 2016). FedAvg has been successfully deployed in a number of domains (Hard et al., 2018; Li et al., 2019), and is the most commonly used optimization algorithm in the federated setting (Yang et al., 2018). In FedAvg, every participating client first downloads and trains the global model on their local dataset for a number of epochs using SGD. The clients upload the difference between their initial and final model to the parameter server, which averages the local updates weighted according to the magnitude of the corresponding local dataset.
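For comparison with FetchSGD, the following is a minimal sketch of one FedAvg round as just described: each client runs a few epochs of SGD locally and uploads its model difference, and the server averages those differences weighted by local dataset size. The function names and the grad_fn interface are illustrative assumptions, not taken from any particular FedAvg implementation.

```python
import numpy as np

def local_update(global_weights, local_data, grad_fn, epochs=1, lr=0.01):
    """Client side: a few epochs of SGD on the local dataset, starting from the downloaded global model."""
    w = global_weights.copy()
    for _ in range(epochs):
        for example in local_data:
            w -= lr * grad_fn(w, example)
    return w - global_weights                  # the client uploads only its model difference

def fedavg_round(global_weights, client_datasets, grad_fn, epochs=1, lr=0.01):
    """Server side: average client differences, weighted by the size of each local dataset."""
    total = sum(len(data) for data in client_datasets)
    avg_delta = np.zeros_like(global_weights)
    for data in client_datasets:
        delta = local_update(global_weights, data, grad_fn, epochs, lr)
        avg_delta += (len(data) / total) * delta
    return global_weights + avg_delta
```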
One major advantage of FedAvg is that it requires no local state, which is necessary for the common case where clients participate only once in training. FedAvg is also communication-efficient in that it can reduce the total number of bytes transferred during training while achieving the same overall performance. However, from an individual client's perspective, there is no communication savings if the client participates in training only once. Achieving high accuracy on a task often requires using a large model, but clients' network connections may be too slow or unreliable to transmit such a large amount of data at once (Yang et al., 2010).

Another disadvantage of FedAvg is that taking many local steps can lead to degraded convergence on non-i.i.d. data. Intuitively, taking many local steps of gradient descent on local data that is not representative of the overall data distribution will lead to local over-fitting, which will hinder convergence (Karimireddy et al., 2019a). When training a model on non-i.i.d. local datasets, the goal is to minimize the average test error across clients. If clients are chosen randomly, SGD naturally has convergence guarantees on non-i.i.d. data, since the average test error is an expectation over which clients participate. However, although FedAvg has convergence guarantees for the i.i.d. setting (Wang and Joshi, 2018), these guarantees do not apply directly to the non-i.i.d. setting as they do with SGD. Zhao et al. (2018) show that FedAvg, using K local steps, converges as O(K/T) on non-i.i.d. data for strongly convex smooth functions, with additional assumptions. In other words, convergence on non-i.i.d. data could slow down as much as proportionally to the number of local steps taken.

Variants of FedAvg have been proposed to improve its performance on non-i.i.d. data. Sahu et al. (2018) propose constraining the local gradient update steps in FedAvg by penalizing the L2 distance between local models and the current global model. Under the assumption that every client's loss is minimized wherever the overall loss function is minimized, they recover the convergence rate of SGD.
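Concretely, the constraint proposed by Sahu et al. (2018) amounts to adding a proximal penalty to each client's local objective, so that local steps are pulled back toward the current global model. The sketch below is a hedged illustration of that idea; the penalty weight mu and the function names are assumptions for exposition, not values from the paper.

```python
import numpy as np

def proximal_local_step(w_local, w_global, grad_fn, example, lr=0.01, mu=0.1):
    """One local SGD step on loss(w) + (mu / 2) * ||w - w_global||^2.

    The extra term penalizes the L2 distance between the local model and the
    current global model, limiting local drift on non-i.i.d. data.
    """
    grad = grad_fn(w_local, example) + mu * (w_local - w_global)
    return w_local - lr * grad
```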