Optimizing Federated Learning on Non-IID Data with Reinforcement Learning

Hao Wang1, Zakhary Kaplan1, Di Niu2 and Baochun Li1 1University of Toronto, {haowang, bli}@ece.utoronto.ca, [email protected] 2University of Alberta, [email protected]

Abstract—The widespread deployment of distributed applications in ubiquitous environments has sparked interest in exploiting the vast amount of data stored on mobile devices. To preserve data privacy, Federated Learning has been proposed to learn a shared model by performing distributed training locally on participating devices and aggregating the local models into a global one. However, due to the limited network connectivity of mobile devices, it is not practical for federated learning to perform model updates and aggregation on all participating devices in parallel. Besides, data samples across all devices are usually not independent and identically distributed (IID), posing additional challenges to the convergence and speed of federated learning.

In this paper, we propose FAVOR, an experience-driven control framework that intelligently chooses the client devices to participate in each round of federated learning to counterbalance the bias introduced by non-IID data and to speed up convergence. Through both empirical and mathematical analysis, we observe an implicit connection between the distribution of training data on a device and the model weights trained based on those data, which enables us to profile the data distribution on that device based on its uploaded model weights. We then propose a mechanism based on deep Q-learning that learns to select a subset of devices in each communication round to maximize a reward that encourages the increase of validation accuracy and penalizes the use of more communication rounds. With extensive experiments performed in PyTorch, we show that the number of communication rounds required in federated learning can be reduced by up to 49% on the MNIST dataset, 23% on FashionMNIST, and 42% on CIFAR-10, as compared to the Federated Averaging algorithm.

(This work is supported by a research contract with Huawei Technologies Co. Ltd. and an NSERC Discovery Research Program.)

I. INTRODUCTION

As the primary computing resource for billions of users, mobile devices are constantly generating massive volumes of data, such as photos, voice, and keystrokes, which are of great value for training machine learning models. However, due to privacy concerns, collecting private data from mobile devices to the cloud for centralized model training is not always possible. To efficiently utilize end-user data, Federated Learning (FL) has emerged as a new paradigm of distributed machine learning that orchestrates model training across mobile devices [1]. With federated learning, locally trained models are communicated to a server for aggregation, without collecting any raw data from users. Federated learning has enabled joint model training over privacy-sensitive data in a wide range of applications, including natural language processing, computer vision, and speech recognition.

However, unlike in a server-based environment, machine learning on mobile devices faces a few fundamental challenges, such as the limited connectivity of wireless networks, the unstable availability of mobile devices, and the non-IID distributions of local datasets, which are hard to characterize statistically [1]–[4]. For these reasons, it is not as practical to perform machine learning on all participating devices simultaneously in a synchronous or semi-synchronous manner as it is in a parameter-server cluster. Therefore, as a common practice in federated learning, only a subset of devices is randomly selected to participate in each round of model training, to avoid long-tailed waiting times due to unstable network conditions and straggler devices [1, 5].

Since the completion time of federated learning is dominated by the communication time, Konečný et al. [6, 7] proposed structured and sketched updates to decrease the time it takes to complete a communication round. McMahan et al. [1] presented the Federated Averaging (FEDAVG) algorithm, which aims to reduce the number of communication rounds required by averaging model weights from client devices instead of applying traditional updates. FEDAVG has also been deployed in a large-scale system [4] to test its effectiveness.

However, existing federated learning methods have not solved the statistical challenges posed by heterogeneous local datasets. Since different users have different device usage patterns, the data samples and labels located on any individual device may follow a different distribution, which cannot represent the global data distribution.
It has recently been pointed out that the performance of federated learning, especially FEDAVG, may significantly degrade in the presence of non-IID data, in terms of both the model accuracy and the number of communication rounds required for convergence [8]–[10]. In particular, FEDAVG randomly selects a subset of devices in each round and averages their local model weights to update the global model. The randomly selected local datasets may not reflect the true data distribution in a global view, which inevitably introduces biases into global model updates. Furthermore, the models locally trained on non-IID data can be significantly different from each other. Aggregating these divergent models can slow down convergence and substantially reduce the model accuracy [8].

In this paper, we present the design and implementation of FAVOR, a control framework that aims to improve the performance of federated learning through intelligent device selection. Based on reinforcement learning, FAVOR aims to accelerate and stabilize the federated learning process by learning to actively select the best subset of devices in each communication round that can counterbalance the bias introduced by non-IID data.

Since privacy concerns have ruled out any possibility of accessing the raw data on each device, it is impossible to profile the data distribution on each device directly. However, we observe that there is an implicit connection between the distribution of the training samples on a device and the model weights trained based on those samples. Therefore, our main intuition is to use the local model weights and the shared global model as states to judiciously select devices that may contribute to global model improvement. We propose a reinforcement learning agent based on a Deep Q-Network (DQN), which is trained through a Double DQN for increased efficiency and robustness. We carefully design the reward signal in order to aggressively improve the global model accuracy per round as well as to reduce the total number of communication rounds required. We also propose a practical scheme that compresses model weights to reduce the high dimensionality of the state space, especially for big models, e.g., deep neural network models.

We have implemented FAVOR in a federated learning simulator developed from scratch using PyTorch (available as an open-source GitHub repository at https://github.com/iqua/flsim), and evaluated it under a variety of federated learning tasks. Our experimental results on the MNIST, FashionMNIST, and CIFAR-10 datasets have shown that FAVOR can reduce the required number of communication rounds in federated learning by up to 49% on MNIST, by up to 23% on FashionMNIST, and by up to 42% on CIFAR-10, as compared to the FEDAVG algorithm.

II. BACKGROUND AND MOTIVATION

In this section, we briefly introduce how federated learning works in general and show how non-IID data poses challenges to existing federated learning algorithms. We then demonstrate how properly selecting client devices in each round can improve the performance of federated learning on non-IID data.

A. Federated Learning

Federated learning (FL) trains a shared global model by iteratively aggregating model updates from multiple client devices, which may have slow and unstable network connections. Initially, eligible client devices first check in with a remote server. The remote server then proceeds with federated learning synchronously in rounds. In each round, the server randomly selects a subset of available client devices to participate in training. The selected devices first download the latest global model from the server, train the model on their local datasets, and report their respective model updates to the server for aggregation.

We formally introduce FL in the context of a C-class classification problem, which is defined over a compact feature space X and a label space Y = [C], where [C] = {1, ..., C}. Let (x, y) denote a particular labeled sample. Let f : X → S denote the prediction function, where S = { z | Σ_{i=1}^{C} z_i = 1, z_i ≥ 0, ∀i ∈ [C] }. That is, the vector-valued function f yields a probability vector z for each sample x, where f_i predicts the probability that the sample belongs to the i-th class.

Let the vector w denote the model weights. For classification, the commonly used training loss is the cross entropy, defined as

ℓ(w) := E_{(x,y)∼p}[ Σ_{i=1}^{C} 1_{y=i} · (−log f_i(x, w)) ] = Σ_{i=1}^{C} p(y = i) · E_{x|y=i}[ −log f_i(x, w) ].

The learning problem is to solve the following optimization problem:

minimize_w  Σ_{i=1}^{C} p(y = i) · E_{x|y=i}[ −log f_i(x, w) ].   (1)
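To make the objective in (1) concrete, the short PyTorch snippet below evaluates the same quantity in two equivalent ways: as the empirical mean cross-entropy over a batch, and as the label-frequency-weighted per-class decomposition used in (1). The tensors, batch size, and class count are illustrative only.

```python
import torch
import torch.nn.functional as F

def objective_empirical(logits, y):
    """Empirical form of Eq. (1): mean cross-entropy over samples."""
    return F.cross_entropy(logits, y)

def objective_per_class(logits, y, num_classes=10):
    """Per-class form of Eq. (1): sum_i p(y=i) * E_{x|y=i}[ -log f_i(x, w) ]."""
    log_probs = F.log_softmax(logits, dim=1)
    total = 0.0
    for i in range(num_classes):
        mask = (y == i)
        if mask.any():
            p_i = mask.float().mean()                      # p(y = i) on this batch
            total = total + p_i * (-log_probs[mask, i]).mean()
    return total

# The two forms agree on the same batch:
logits, y = torch.randn(32, 10), torch.randint(0, 10, (32,))
assert torch.allclose(objective_empirical(logits, y),
                      objective_per_class(logits, y), atol=1e-6)
```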
In federated learning, suppose there are N client devices in total. The k-th device holds m^{(k)} data samples drawn from the data distribution p^{(k)}, which is a joint distribution of the samples {x, y} on this device. In each round t, K devices are selected (at random in the original FL), each of which downloads the current global model weights w_{t−1} from the server and performs the following stochastic gradient descent (SGD) training locally:

w_t^{(k)} = w_{t−1} − η ∇ℓ(w_{t−1}) = w_{t−1} − η Σ_{i=1}^{C} p^{(k)}(y = i) ∇_w E_{x|y=i}[ −log f_i(x, w_{t−1}) ],   (2)

where η is the learning rate and ℓ is evaluated on device k's local data distribution p^{(k)}. Note that in this notation with expectations, w_t^{(k)} can be the result of one or multiple epochs of local SGD training [1].

Once w_t^{(k)} is obtained, each participating device k reports a model weight difference Δ_t^{(k)} to the FL server [1], which is defined as

Δ_t^{(k)} := w_t^{(k)} − w_{t−1}.

Devices can also upload their local models w_t^{(k)} directly, but transferring Δ_t^{(k)} is more amenable to compression and communication reduction [4]. After the FL server has collected updates from all K participating devices in round t, it performs the federated averaging (FEDAVG) algorithm to update the global model:

Δ_t = Σ_{k=1}^{K} m^{(k)} Δ_t^{(k)} / Σ_{k=1}^{K} m^{(k)},
w_t = w_{t−1} + Δ_t.
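The aggregation rule above can be made concrete with the following minimal PyTorch-style sketch of one FEDAVG round. The helper names (local_sgd, fedavg_round, make_model, selected_clients) are illustrative rather than part of the paper's implementation, and the sketch assumes all state-dict entries are floating-point tensors.

```python
import copy
import torch

def local_sgd(model, loader, epochs=1, lr=0.01):
    """One or more epochs of local SGD on a client's data (cf. Eq. (2))."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model.state_dict()

def fedavg_round(global_weights, selected_clients, make_model):
    """Run one FEDAVG round over the selected clients and return w_t."""
    deltas, sizes = [], []
    for loader, m_k in selected_clients:            # (data loader, #samples m^(k))
        model = make_model()
        model.load_state_dict(copy.deepcopy(global_weights))
        local_w = local_sgd(model, loader)
        # Delta_t^(k) = w_t^(k) - w_{t-1}
        deltas.append({n: local_w[n] - global_weights[n] for n in global_weights})
        sizes.append(float(m_k))
    total = sum(sizes)
    # Delta_t = sum_k m^(k) Delta_t^(k) / sum_k m^(k);  w_t = w_{t-1} + Delta_t
    return {n: global_weights[n] +
               sum(d[n] * (s / total) for d, s in zip(deltas, sizes))
            for n in global_weights}
```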
B. The Challenges of Non-IID Data Distribution

FEDAVG has been shown to approximate well the model trained on centrally collected data, provided that the data and label distributions on different devices are IID [1]. However, in reality, the data owned by each device are typically non-IID, i.e., p^{(k)} differs across devices, due to different user preferences and usage patterns. When data distributions are non-IID, FEDAVG is unstable and may even diverge [8]. This is due to the inconsistency between the locally performed SGD algorithm, which aims to minimize the loss value on the m^{(k)} local samples on each device, and the global objective of minimizing the overall loss on all Σ_{k=1}^{K} m^{(k)} data samples. As we keep fitting models on different devices to heterogeneous local data, the divergence among the weights w^{(k)} of these local models accumulates and eventually degrades the performance of learning [8, 10], leading to more communication rounds before training converges. Reducing the number of communication rounds in federated learning is critically important for mobile devices, which have limited computation capacity and communication bandwidth.

We use an experiment to demonstrate that non-IID data may slow down the convergence of federated learning if devices are randomly selected in each round for model aggregation. Instead of random selection, we show that selecting devices with a clustering algorithm can help to even out the data distribution and speed up convergence.

Fig. 1: Training a CNN model on non-IID MNIST data (test accuracy vs. communication rounds for FedAvg-IID, FedAvg-non-IID, and K-Center-non-IID).

In our experiment, we train a two-layer CNN model with PyTorch on the MNIST dataset (containing 60,000 samples) until the model achieves a 99% test accuracy. We consider 100 devices in total, each running on a different thread, and we train the CNN model under the IID and non-IID settings, respectively. For the IID setting, the 60,000 samples are randomly distributed among the 100 devices, such that each device owns 600 samples. For the non-IID setting, each device still owns 600 samples, yet 80% of them come from a dominant class and the remaining 20% belong to other classes. For example, a "0"-dominated device has 480 data samples with the label "0", while the remaining 120 data samples have labels evenly distributed among "1" to "9". In each round, ten devices are selected to participate in federated learning. Note that FL selects only a subset of all devices in each round to balance computation complexity and convergence rate, and to avoid excessively long waiting times for model aggregation [1, 4], which is also discussed in Sec. IV-D.

FEDAVG randomly selects ten devices to participate in each round of training. As shown in Fig. 1, FEDAVG takes 47 rounds to achieve the target accuracy in the IID setting (the same setting costs FEDAVG 50 rounds on TensorFlow in [1]). However, FEDAVG takes 163 rounds to achieve the 99% target accuracy on non-IID data, which more than triples the number of communication rounds.

To reduce the required communication rounds under the non-IID setting, we then tuned the device selection strategy. While in FL it is impossible to peek into the data distribution of each device, we observe that there is an implicit connection between the data distribution on a device and the model weights trained on that device. Therefore, we may profile each device by analyzing its local model weights after the first round.

We provide some rationale for this observation. According to (2), given the initial global model weights w_init, the local weights on device k after one epoch of SGD can be represented by

w_1^{(k)} = w_init − η Σ_{i=1}^{C} p^{(k)}(y = i) ∇_w E_{x|y=i}[ −log f_i(x, w_init) ].

The weight divergence between device k and device k' can be measured as ||w_1^{(k)} − w_1^{(k')}||, where k, k' ∈ [K]. Then, by using a bounding technique similar to that adopted in [8], we can derive the following bound on the weight divergence after the first round:

||w_1^{(k')} − w_1^{(k)}|| ≤ η g_max(w_init) Σ_{i=1}^{C} || p^{(k')}(y = i) − p^{(k)}(y = i) ||,

where g_max(w) := max_{i=1,...,C} || ∇_w E_{x|y=i}[ −log f_i(x, w) ] ||. This implies that even if local models are trained from the same initial weights w_init, there is an implicit connection between the discrepancy of the data distributions on devices k and k', i.e., the term Σ_{i=1}^{C} || p^{(k')}(y = i) − p^{(k)}(y = i) ||, and their weight divergence.

Based on the insights above, we first let each of the 100 devices download the same initial global weights (randomly generated) and perform one epoch of SGD based on its local data to obtain w_1^{(k)}, k = 1, ..., 100.
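For reference, the sketch below constructs the non-IID partition described in this experiment, where 80% of each device's samples come from one dominant class. The helper non_iid_partition is hypothetical and only mirrors the described setup.

```python
import numpy as np

def non_iid_partition(labels, num_devices=100, samples_per_device=600,
                      dominant_frac=0.8, num_classes=10, seed=0):
    """Assign each device mostly-one-class data: 80% from a dominant label,
    the remaining 20% spread over the other labels."""
    rng = np.random.default_rng(seed)
    by_class = {c: list(rng.permutation(np.where(labels == c)[0]))
                for c in range(num_classes)}
    partition = []
    for d in range(num_devices):
        dom = d % num_classes                               # dominant class of device d
        n_dom = int(dominant_frac * samples_per_device)     # e.g., 480 samples
        idx = [by_class[dom].pop() for _ in range(n_dom)]
        others = [c for c in range(num_classes) if c != dom]
        for j in range(samples_per_device - n_dom):          # remaining 120 samples
            c = others[j % len(others)]
            idx.append(by_class[c].pop())
        partition.append(np.array(idx))
    return partition
```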
We then apply the K-Center clustering algorithm [11] to {w_1^{(1)}, ..., w_1^{(100)}} to cluster the 100 devices into 10 groups. In each round, we randomly select one device from each group to participate in FL training, which still results in 10 simultaneously participating devices per round. As shown in Fig. 1, the K-Center algorithm takes 131 rounds to reach the target accuracy under the non-IID setting, which is 32 rounds fewer than the vanilla FEDAVG algorithm. This example implies that it is possible to improve the performance of federated learning by carefully selecting the set of participating devices in each round, especially under the non-IID data setting.
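One way to realize this grouping step is the classic greedy k-center heuristic sketched below, operating on flattened local weight vectors; the exact clustering procedure of [11] may differ, and the function names are illustrative.

```python
import numpy as np

def greedy_k_center(points, k, seed=0):
    """Greedy 2-approximation for k-center: repeatedly add the point
    farthest from the current set of centers."""
    rng = np.random.default_rng(seed)
    centers = [int(rng.integers(len(points)))]
    dist = np.linalg.norm(points - points[centers[0]], axis=1)
    while len(centers) < k:
        nxt = int(np.argmax(dist))
        centers.append(nxt)
        dist = np.minimum(dist, np.linalg.norm(points - points[nxt], axis=1))
    return centers

def group_devices(local_weights, k=10):
    """Cluster devices by assigning each to its nearest k-center center."""
    pts = np.stack([w.ravel() for w in local_weights])        # one row per device
    centers = greedy_k_center(pts, k)
    assign = np.argmin(
        np.stack([np.linalg.norm(pts - pts[c], axis=1) for c in centers]), axis=0)
    return [np.where(assign == g)[0] for g in range(k)]        # device ids per group

# Each round, draw one device from every group:
# selected = [np.random.choice(g) for g in group_devices(all_local_weights)]
```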
C. Deep Reinforcement Learning (DRL)

Reinforcement Learning (RL) is the learning process of an agent that acts in response to its environment in order to maximize its rewards. The recent success of RL [12, 13] comes from the capability of RL agents to learn from interaction with a dynamic environment. Specifically, at each time step t, the RL agent observes a state s_t and performs an action a_t. The state of the environment then transits to s_{t+1}, and the agent receives a reward r_t. The state transitions and rewards follow a Markov Decision Process (MDP) [14], which is a discrete-time stochastic control process, denoted by a sequence s_1, a_1, r_1, s_2, ..., r_{t−1}, s_t, a_t, .... The objective of RL is to maximize the expectation of the cumulative discounted return R = Σ_{t=1}^{T} γ^{t−1} r_t, where γ ∈ (0, 1] is a factor discounting future rewards.

The RL agent seeks to learn a "cheat sheet" that points out the actions leading to the maximum cumulative return under a particular state. Value-based RL algorithms formally use an action-value function Q(s_t, a) to estimate the expected return starting from state s_t:

Q_π(s_t, a) := E_π[ Σ_{k=1}^{∞} γ^{k−1} r_{t+k−1} | s_t, a ] = E_{s_{t+1}, a}[ r_t + γ Q_π(s_{t+1}, a) | s_t, a_t ],

where π is the policy mapping from states to probabilities of selecting possible actions. The optimal action-value function Q*(s_t, a) is the "cheat sheet" sought by the RL agent, defined as the maximum expectation of the cumulative discounted return starting from s_t:

Q*(s_t, a) := E_{s_{t+1}}[ r_t + γ max_{a'} Q*(s_{t+1}, a') | s_t, a ].   (3)

We can then apply function-approximation techniques to learn a parameterized value function Q(s_t, a; θ_t) approximating the optimal value function Q*(s_t, a). The one-step lookahead r_t + γ max_a Q(s_{t+1}, a; θ_t) becomes the target that Q(s_t, a; θ_t) learns to match. Typically, a deep neural network (DNN) is used to represent the function approximator [15]. The RL learning problem then becomes minimizing the MSE loss between the target and the approximator, defined as

ℓ_t(θ_t) := ( r_t + γ max_a Q(s_{t+1}, a; θ_t) − Q(s_t, a; θ_t) )².   (4)
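The one-step Q-learning objective in Eq. (4) can be written compactly in PyTorch as below. The small MLP approximator and the replay tuples are placeholders and not the exact architecture used by FAVOR.

```python
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    """A small MLP approximator Q(s, .; theta) over N discrete actions."""
    def __init__(self, state_dim, num_actions, hidden=512):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, num_actions))

    def forward(self, s):
        return self.net(s)

def q_learning_loss(q_net, s, a, r, s_next, gamma=0.99):
    """MSE between the one-step target r + gamma * max_a' Q(s', a')
    and Q(s, a), as in Eq. (4)."""
    q_sa = q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        target = r + gamma * q_net(s_next).max(dim=1).values
    return ((target - q_sa) ** 2).mean()
```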
III. DRL FOR CLIENT SELECTION

We formulate device selection for federated learning as a deep reinforcement learning (DRL) problem. Based on this formulation, we present a federated learning workflow driven by DRL in Sec. III-B. We then introduce the challenges of dealing with high-dimensional models in federated learning. Finally, we describe training the DRL agent with the double Deep Q-learning Network (DDQN) [16].

The per-round FL process can be modeled as a Markov Decision Process (MDP), with the state s represented by the global model weights and the model weights of each client device in each round. Given a state, the DRL agent takes an action a to select a subset of devices that perform local training and update the global model. A reward signal r is then observed, which is a function of the test accuracy achieved by the global model so far on a held-out validation set. The objective is to train the DRL agent to converge to the target accuracy for federated learning as quickly as possible.

In the proposed framework, the agent does not need to collect or check any data samples from mobile devices; only model weights are communicated, thus preserving the same level of privacy as the original FL does. It solely relies on model weight information to determine which device may improve the global model the most, since there is an implicit connection between the data distribution on a device and its local model weights obtained by performing SGD on those data. In fact, in Sec. III-C we will show that even after dimension reduction, the divergence between local model weights remains evident and contains enough information to guide device selection.

A. The Agent Based on a Deep Q-Network

Suppose there is a federated learning job on N available devices with a target accuracy Ω. In each round, using a Deep Q-Network (DQN) [15], K devices are selected to participate in the training. Considering the limited traces available from federated learning jobs, a DQN can be trained more efficiently and can reuse data more effectively than policy gradient methods and actor-critic methods.

State: Let the state of round t be represented by a vector s_t = (w_t, w_t^{(1)}, ..., w_t^{(N)}), where w_t denotes the weights of the global model after round t, and w_t^{(1)}, ..., w_t^{(N)} are the model weights of the N devices, respectively. The agent is collocated with the FL server and maintains a list of model weights {w^{(k)} | k ∈ [N]}; a particular w^{(k)} is updated in round t only if device k is selected for training and the resulting Δ_t^{(k)} is received by the FL server. Therefore, no extra communication overhead is introduced for the devices.

The resulting state space can be huge, e.g., a CNN model can contain millions of weights, and it is challenging to train a DQN with such a large state space. In practice, we propose to apply an effective and lightweight dimension reduction technique to the state space, i.e., to the model weights, which will be described in detail in Sec. III-C.

Action: At the beginning of each round t, the agent needs to decide which subset of K devices to select from the N devices. This selection would result in a large action space of size (N choose K), which complicates RL training. We propose a trick to keep the action space small while still leveraging the intelligent control provided by the DRL agent. In particular, the agent is trained by selecting only one out of the N devices to participate in FL per round based on the DQN, while at testing and application time the agent samples a batch of top-K clients to participate in FL. That is, the DQN agent learns an approximator of the optimal action-value function Q*(s_t, a) through a neural network, which estimates the action that maximizes the expected return starting from s_t. The action space is thus reduced to {1, 2, ..., N}, where a = i means that device i is selected to participate in FL.

Once the DQN has been trained to approximate Q*(s, a), during testing, in round t the DQN agent computes {Q*(s_t, a) | a ∈ [N]} for all N actions. Each action value indicates the maximum expected return that the agent can obtain by selecting a particular action a at state s_t. We then select the K devices, each corresponding to a different action a, that lead to the top-K values of Q*(s_t, a).

Reward: We set the reward observed at the end of each round t to r_t = Ξ^{(ω_t − Ω)} − 1, t = 1, ..., T, where ω_t is the testing accuracy achieved by the global model on the held-out validation set after round t, Ω is the target accuracy, and Ξ is a positive constant that ensures that r_t grows exponentially with the testing accuracy ω_t. We have r_t ∈ (−1, 0], since 0 ≤ ω_t ≤ Ω ≤ 1. The federated learning stops when ω_t = Ω, at which point r_t reaches its maximum value of 0.

The DQN agent is trained to maximize the expectation of the cumulative discounted reward given by

R = Σ_{t=1}^{T} γ^{t−1} r_t = Σ_{t=1}^{T} γ^{t−1} ( Ξ^{(ω_t − Ω)} − 1 ),

where γ ∈ (0, 1] is a factor discounting future rewards.

We now explain the motivations behind the two terms Ξ^{(ω_t − Ω)} and −1 in r_t. The first term, Ξ^{(ω_t − Ω)}, incentivizes the agent to select devices that achieve a higher test accuracy ω_t. Ξ controls how fast the reward r_t grows with ω_t. In general, as ML training proceeds, the model accuracy increases at a slower pace, which means |ω_t − ω_{t−1}| decreases as the round t increases. Therefore, we use an exponential term to amplify the marginal accuracy increase as FL progresses into later stages. Ξ is set to 64 in our experiments.

The second term, −1, encourages the agent to complete training in fewer rounds, because the more rounds it takes, the less cumulative reward the agent will receive.

B. Workflow

Fig. 2: The federated learning workflow with FAVOR (Steps 1–5: check-in, initialization, device selection by the DRL agent, local training, and reporting).

Fig. 2 shows how our system FAVOR performs federated learning with a DRL agent selecting devices in each round, following the steps below:

Step 1: All N eligible devices check in with the FL server.
Step 2: Each device downloads the initial random model weights w_init from the server, performs local SGD training for one epoch, and returns the resulting model weights {w_1^{(k)}, k ∈ [N]} to the FL server.
Step 3: In round t, where t = 1, 2, ..., upon receiving the uploaded local weights, the corresponding copies of local model weights stored on the server are updated. The DQN agent computes Q(s_t, a; θ) for all devices a = 1, ..., N.
Step 4: The DQN agent selects the K devices corresponding to the top-K values of Q(s_t, a; θ), a = 1, ..., N. The selected K devices download the latest global model weights w_t and perform one epoch of SGD locally to obtain {w_{t+1}^{(k)} | k ∈ [K]}.
Step 5: {w_{t+1}^{(k)} | k ∈ [K]} are reported (uploaded) to the server to compute w_{t+1} based on FEDAVG. Move into round t+1 and repeat Steps 3–5.

Steps 3–5 are repeated until completion, e.g., until a target accuracy is reached or after a certain number of rounds. Note that the above workflow can be easily tweaked to let each selected client upload the model difference Δ (as compared to the last version of the local model) instead of the new local model itself, without changing the underlying logic. The FL server can always use the difference to reconstruct a copy of the corresponding local model stored on it.

When a device k has new training data, it will request the global model weights w_t and perform local SGD training for one epoch as in Step 2, which generates new local model weights w_t^{(k)}. The device then pushes w_t^{(k)} to the FL server, and the FL server updates the state s_t to make the DRL agent aware of the device with updated data. It should be noted that the data on each device remain unchanged during DRL training to ensure that the training follows the Markov Decision Process. The overhead introduced is minimal, since no extra computation or communication overhead is added to the devices. The only overhead is that the FL server now needs to store a copy of the latest local model weights from each device in order to form the state s_t. However, the FL server is deployed in the cloud, which provisions sufficient on-demand computation and storage capacity.
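The following sketch condenses Steps 3–5 and the reward defined in Sec. III-A. It assumes the PCA-compressed weight vectors are concatenated into the state and that a trained Q-network like the one sketched earlier is available; favor_select, q_net, and the vector names are illustrative.

```python
import numpy as np
import torch

def reward(test_accuracy, target=0.99, xi=64.0):
    """Per-round reward r_t = Xi^(omega_t - Omega) - 1, which lies in (-1, 0]."""
    return xi ** (test_accuracy - target) - 1.0

def favor_select(q_net, global_vec, local_vecs, k=10):
    """Steps 3-4: score every device with the DQN and pick the top-K.
    global_vec / local_vecs are PCA-compressed weight vectors."""
    state = torch.tensor(np.concatenate([global_vec] + list(local_vecs)),
                         dtype=torch.float32).unsqueeze(0)   # e.g., shape (1, 10100)
    with torch.no_grad():
        q_values = q_net(state).squeeze(0)                   # one value per device
    return torch.topk(q_values, k).indices.tolist()          # selected device ids

# Step 5 (per round): the selected devices train locally for one epoch, upload
# their new weights, the server runs FEDAVG to obtain w_{t+1}, refreshes the
# stored local copies, and forms the next state s_{t+1}.
```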
C. Dimension Reduction

One issue with the proposed DQN agent is that it uses the weights of the global model and of all local models as the state, which leads to a very large state space. Many deep neural networks (e.g., CNNs) have millions of weights, making it challenging to train the DQN agent over such a high-dimensional state space. To reduce the dimension of the state space, we propose to apply principal component analysis (PCA) to the model weights and to use the compressed model weights to represent states instead.

Specifically, we compute the PCA loading vectors of different principal components based only on the local model weights {w_1^{(k)}, k ∈ [N]} obtained in Step 2. In subsequent rounds, these loading vectors are reused to obtain the first several principal components of the model weights {w_t^{(k)}, k ∈ [N]}, t = 2, 3, ..., i.e., through a linear transformation, without fitting the PCA model again. Therefore, the overhead of compressing model weights in subsequent rounds is negligibly small.

We demonstrate the effectiveness of using PCA-compressed model weights to differentiate between data distributions through a simple experiment. Consider federated learning with a two-layer CNN model (with 431,080 model weights in total) on the MNIST dataset, trained with PyTorch on 100 devices. Data samples are distributed in the same way as in the experiment in Sec. II: each client has 80% of its training samples belonging to a dominant class, while the remaining 20% of its samples have random labels. After five epochs of local SGD training at Round 1, we project the local model weight vectors of the 100 devices onto a two-dimensional space spanned by the first and second principal components.

Fig. 3: PCA of CNN weights on the MNIST dataset (local model weights of the 100 devices at Round 1, projected onto the first two principal components C0 and C1).

As shown in Fig. 3, different shapes (or colors) indicate local models trained on devices with different dominant labels. For example, all the yellow "+" signs denote the compressed local models from devices with the dominant label "6". Even with the dimensions reduced from 431,080 to only two, we can observe a clustering of local models according to their dominant labels, even at Round 1. Therefore, we use PCA-compressed model weights as states to enable efficient DQN agent training in the evaluation of Sec. IV.

D. Training the Agent with Double DQN (DDQN)

We propose to use the double deep Q-learning network (DDQN) [16] to learn the action-value function based on which potential devices can be selected. However, the original Q-learning algorithms can be unstable, since they indirectly optimize the agent's performance by learning an approximator Q(s_t, a; θ) of the optimal action-value function Q*(s_t, a). DDQN adds another action-value function for value estimation: the idea behind DDQN is that a periodically frozen network adds stability to the action-value evaluation, which becomes less prone to "jittering."

Specifically, double Q-learning uses two action-value functions, with online parameters θ_t and frozen parameters θ⁻_t, to update the value estimation. The target at round t, Y_t^{DoubleQ}, is defined as

Y_t^{DoubleQ} := r_t + γ Q( s_{t+1}, argmax_a Q(s_{t+1}, a; θ_t); θ⁻_t ),   (5)

and the online action-value function learns to minimize the loss ℓ_t(θ_t) := ( Y_t^{DoubleQ} − Q(s_t, a_t; θ_t) )² by gradient descent:

θ_{t+1} = θ_t + α ( Y_t^{DoubleQ} − Q(s_t, a_t; θ_t) ) ∇_{θ_t} Q(s_t, a_t; θ_t),   (6)

where θ_t denotes the online parameters updated at each time step, θ⁻_t denotes the frozen parameters that add stability to the action-value estimation, and α is a scalar step size.

Fig. 4: The DDQN agent interacting with the FL server (the state assembled from the global and local model weights is fed into the online DQN; the selected action, the observed reward, and the next state from the FL environment are used to update the two action-value networks).

To train the DRL agent, the FL server first performs random device selection to initialize the states. As shown in Fig. 4, the states are fed into one of the double DQNs; the DQN generates an action, through a softmax over the action values, to select a device for the FL server. After several rounds of FL training, the DRL agent has sampled a number of state-action pairs, with which it updates the two action-value functions according to (5) and (6).

IV. EVALUATION

We have implemented FAVOR with PyTorch in around 2,000 lines of Python code, which we have released as an open-source project. With the lightweight threading library, we simulate a large number of devices running real-world models, each device as a lightweight thread. We evaluated FAVOR by training popular CNN models on three benchmark datasets, MNIST, FashionMNIST, and CIFAR-10, with FEDAVG and K-Center as baselines for comparison, and we evaluated the accuracy of the trained models using the testing set from each dataset. Our experimental results show that FAVOR can reduce the communication rounds by up to 49% on MNIST, up to 23% on FashionMNIST, and up to 42% on CIFAR-10, compared to the FEDAVG algorithm. We briefly describe our methodology and settings as follows.

Datasets and models: We explored different combinations of hyper-parameters for the CNN models on the different datasets and chose the hyper-parameters of FAVOR leading to the best performance.

• MNIST. We train a CNN model with two 5×5 convolution layers. The first layer has 20 output channels and the second has 50, with each layer followed by 2×2 max pooling. On each device, the batch size is 50 and the epoch number is five.
• FashionMNIST. We train a CNN model with two 5×5 convolution layers. The first layer has 16 output channels and the second has 32, with each layer followed by 2×2 max pooling. On each device, the batch size is ten and the epoch number is five.
• CIFAR-10. We train a CNN model with two 5×5 convolution layers. The first layer has six output channels and the second has 16, with each layer followed by 2×2 max pooling. On each device, the batch size is 100 and the epoch number is five.

Performance metrics: In federated learning, due to the limited computation capacity and network bandwidth of mobile devices, reducing the number of communication rounds is crucially important. Thus, we use the number of communication rounds as the performance metric.

A. Training the DRL Agent

We train the DRL agent on the different datasets with 100 available devices. The DDQN model in the DRL agent consists of two-layer MLP networks with 512 hidden units, where the input size is 10,100: we have 101 weight vectors (the global model weights w and the local model weights {w^{(k)} | k ∈ [100]} from the 100 devices), each reduced by PCA to 100 dimensions. The size of the output layer is 100; each output, passed through a softmax layer, becomes the probability of selecting a particular device. We train the DRL agent on an AWS EC2 p2.2xlarge instance with a K80 GPU. Training the DRL agent is lightweight: each training iteration takes seconds on the GPU, and reducing the model weights from 431,080 to 100 dimensions with sklearn.decomposition.PCA takes minutes.

Fig. 5 plots the training progress of the DRL agent on the three FL tasks. An episode starts at the initialization of a federated learning job and ends when the job converges to the target accuracy. The total return is the cumulative discounted reward obtained in one episode. The target accuracy Ω is set to 99% for training on MNIST, 85% for training on FashionMNIST, and 55% for training on CIFAR-10.

Fig. 5: Training the DRL agent (total return per episode for (a) DRL training on MNIST, (b) DRL training on FashionMNIST, and (c) DRL training on CIFAR-10).
B. Different Levels of Non-IID Data

We investigate different levels of non-IID data to verify the efficiency of FAVOR. We use σ to denote the four different levels of non-IID data: σ = 1.0 indicates that the data on each device belong to only one label; σ = 0.8 indicates that 80% of the data belong to one label and the remaining 20% belong to other labels; σ = 0.5 indicates that 50% of the data belong to one label and the remaining 50% belong to other labels; and σ = H indicates that the data on each device are IID. At each round, K devices are selected from the 100 devices; K is set to 10, as in [1]. We include the performance of FEDAVG and K-Center as comparisons. We plot the communication rounds of federated learning on the three benchmarks in Fig. 6, Fig. 7, and Fig. 8.

Fig. 6: Accuracy vs. communication rounds for different levels of non-IID MNIST data (σ = 1, 0.8, 0.5, H), comparing FAVOR, K-Center, and FEDAVG.

Fig. 7: Accuracy vs. communication rounds for different levels of non-IID FashionMNIST data (σ = 1, 0.8, 0.5, H), comparing FAVOR, K-Center, and FEDAVG.

Fig. 8: Accuracy vs. communication rounds for different levels of non-IID CIFAR-10 data (σ = 1, 0.8, 0.5, H), comparing FAVOR, K-Center, and FEDAVG.

TABLE I: The number of communication rounds required to reach the target accuracy for FEDAVG, FAVOR, and K-Center on MNIST, FashionMNIST, and CIFAR-10, under the IID setting and under σ = 0.5, 0.8, and 1.0.

Each entry in Table I shows the number of communication rounds necessary to achieve a test-set accuracy of 99% on MNIST, 85% on FashionMNIST, and 55% on CIFAR-10 for the corresponding CNN model. Italic numbers indicate that the model converges to a test-set accuracy lower than the target: the model trained on MNIST with data distribution σ = 1.0 converges to a test-set accuracy of 96%, the models trained on FashionMNIST with σ = 1.0 converge to a test-set accuracy of 80%, and the model trained on CIFAR-10 with σ = 1.0 converges to a test-set accuracy of 50%.

Our experimental results show that there is no guarantee that K-Center can always outperform FEDAVG: devices selected from the clusters may introduce even more bias to federated learning than devices randomly selected by FEDAVG. However, FAVOR reduces the number of communication rounds by up to 49% on MNIST, up to 23% on FashionMNIST, and up to 42% on CIFAR-10, compared to the FEDAVG algorithm.

C. Device Selection and Weight Updates

We save the global model weights and the local model weights per round when we train the two-layer CNN on MNIST with σ = 0.8. The saved model weights are reduced to two-dimensional vectors by PCA and plotted in Fig. 9. By examining the model weights trained with FEDAVG and FAVOR, respectively, Fig. 9 shows that FAVOR updates the global model with larger weight updates than FEDAVG in the early rounds, which leads to a faster convergence speed.

Fig. 9: PCA on model weights of FL training on MNIST ((a) training with FEDAVG, (b) training with FAVOR; w_init is the initial global model and w_1, ..., w_5 are the global model weights at Rounds 1–5, plotted together with the local weights).

D. Increasing Parallelism

We also studied the performance of FAVOR with different numbers of selected devices. The CIFAR-10 dataset is distributed to 100 devices with the non-IID level at σ = 0.8. We apply FAVOR, FEDAVG, and K-Center to train the same CNN on CIFAR-10 separately. The number of selected devices K is set to 10, 50, and 100 to study the performance of federated learning under different degrees of parallelism. Fig. 10 shows that increasing the parallelism does not reduce the number of communication rounds, and may even increase it.

Fig. 10: FL training on CIFAR-10 with different levels of parallelism (FEDAVG: K = 10 takes 87 rounds, K = 50 takes 89, K = 100 takes 128; K-Center: K = 10 takes 74, K = 50 takes 76, K = 100 takes 128; FAVOR: K = 10 takes 61, K = 50 takes 78, K = 100 takes 128).

V. RELATED WORK

Federated learning allows machine learning models to be trained on mobile devices in a distributed fashion without violating user privacy. Existing studies on federated learning have mostly focused on improving its efficiency. We classify the existing literature into the following two categories.

Communication efficiency. Mobile devices typically have unstable and expensive connections, and existing works have attempted to improve the communication efficiency of federated learning. Konečný et al. [6, 7] proposed structured and sketched updates to decrease the completion time of each communication round. Bonawitz et al. [3] developed a secure aggregation protocol that enables a server to perform computation over high-dimensional data from the devices, which is widely deployed in production environments [4, 17]. McMahan et al. [1] presented the Federated Averaging (FEDAVG) algorithm, which allows devices to perform local training for multiple epochs and thereby further reduces the number of communication rounds by averaging model weights from the client devices. Nishio et al. [18] proposed a resource-aware selection algorithm that maximizes the number of participating devices in each round. Sattler et al. [9] proposed a compression framework, Sparse Ternary Compression (STC), that reduces the communication costs and remains robust to non-IID data. FAVOR applies DRL to select the best subset of participating devices to minimize the number of communication rounds, which is orthogonal to these studies on communication efficiency.

Sample efficiency. Unlike centralized machine learning, federated learning performs training on non-IID data residing on devices. Zhao et al. [8] presented a mathematical demonstration showing that non-IID data reduces the accuracy of federated learning by a substantial margin, and proposed to push a small set of uniformly distributed data to the participating devices; downloading extra data, however, further increases the communication cost and computation workload for the devices. Mohri et al. [10] proposed an agnostic federated learning framework that promotes fairness and avoids biases due to non-IID data on the devices. In contrast, FAVOR is the first work that counterbalances the bias from different non-IID data by dynamically constructing the subset of participating devices with DRL techniques.

VI. CONCLUDING REMARKS

In this paper, we presented our design and implementation of FAVOR, an experience-driven federated learning framework that performs active device selection to minimize the number of communication rounds of FL training. We argue that non-IID data exacerbates the divergence of model weights on participating devices and increases the number of communication rounds of federated learning by a substantial margin. With both mathematical demonstrations and empirical studies, we found an implicit connection between model weights and the distribution of the data that the model is trained on. We proposed to actively select a specific subset of devices to participate in training in each round, in order to counterbalance the bias introduced by non-IID data on each device and to speed up FL training by minimizing the number of communication rounds. In particular, we designed a DRL-based agent that applies the DDQN algorithm to select the best subset of devices to achieve our objectives. We have implemented an open-source prototype of FAVOR with PyTorch in more than 2K lines of code and evaluated it with a variety of ML models.
An Communication efficiency. Mobile devices typically have extensive comparison with FL training jobs by FEDAVG has unstable and expensive connections, and existing works have shown that FL training with FAVO R has reduced the number of attempted to improve the communication efficiency of fed- communication rounds by up to 49% on the MNIST dataset, erated learning. Konecnˇ y` et al. [6, 7] proposed structured up to 23% on FashionMNIST, and up to 42% on CIFAR-10. REFERENCES [1] H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y. Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data,” in Proc. the Artificial Intelligence and Statistics Conference (AISTATS), 2017. [2] V. Smith, C.-K. Chiang, M. Sanjabi, and A. S. Talwalkar, “Federated Multi-Task Learning,” in Proc. the Advances in Neural Information Processing Systems (NeurIPS), 2017. [3] K. Bonawitz, V. Ivanov, B. Kreuter, A. Marcedone, H. B. McMahan, S. Patel, D. Ramage, A. Segal, and K. Seth, “Practical Secure Aggre- gation for Privacy-Preserving Machine Learning,” in Proc. the ACM SIGSAC Conference on Computer and Communications Security (CCS), 2017. [4] K. Bonawitz, H. Eichner, W. Grieskamp, D. Huba et al., “Towards Federated Learning at Scale: System Design,” in Proc. the Conference on Systems and Machine Learning (SysML), 2019. [5] Y. Lin, S. Han, H. Mao, Y. Wang, and W. J. Dally, “Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training,” in Proc. the International Conference on Learning Represen- tations (ICLR), 2018. [6] J. Konecnˇ y,` B. McMahan, and D. Ramage, “Federated optimization: Distributed optimization beyond the datacenter,” in Proc. the NIPS Workshop on Optimization for Machine Learning (OPT), 2015. [7] J. Konecnˇ ý, H. B. McMahan, F. X. Yu, P. Richtarik, A. T. Suresh, and D. Bacon, “Federated Learning: Strategies for Improving Communica- tion Efficiency,” in Proc. the NIPS Workshop on Private Multi-Party Machine Learning, 2016. [8] Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra, “Federated Learning with Non-IID Data,” arXiv preprint arXiv:1806.00582, 2018. [9] F. Sattler, S. Wiedemann, K.-R. Müller, and W. Samek, “Robust and Communication-Efficient Federated Learning from Non-IID Data,” arXiv preprint arXiv:1903.02891, 2019. [10] M. Mohri, G. Sivek, and A. T. Suresh, “Agnostic Federated Learning,” in International Conference on Machine Learning (ICML), 2019. [11] O. Sener and S. Savarese, “Active Learning for Convolutional Neural Networks: A Core-Set Approach,” in Proc. the International Conference on Learning Representations (ICLR), 2018. [12] H. Mao, M. Alizadeh, I. Menache, and S. Kandula, “Resource Man- agement with Deep Reinforcement Learning,” in Proc. the 15th ACM Workshop on Hot Topics in Networks. ACM, 2016. [13] A. Mirhoseini, H. Pham, Q. Le, M. Norouzi, S. Bengio, B. Steiner, Y. Zhou, N. Kumar, R. Larsen, and J. Dean, “Device Placement Opti- mization with Reinforcement Learning,” in International Conference on Machine Learning (ICML), 2017. [14] R. S. Sutton and A. G. Barto, Reinforcement learning: An introduction. MIT press Cambridge, 1998. [15] M. Volodymyr, K. Kavukcuoglu, D. Silver, A. Graves, and I. Antonoglou, “Playing Atari with Deep Reinforcement Learning,” in Proc. NIPS Workshop, 2013. [16] H. Van Hasselt, A. Guez, and D. Silver, “Deep Reinforcement Learning with Double Q-Learning,” in Proc. the Thirtieth AAAI Conference on Artificial Intelligence (AAAI), 2016. [17] T. Yang, G. Andrew, H. Eichner, H. Sun, W. Li, N. 
Kong, D. Ram- age, and F. Beaufays, “Applied Federated Learning: Improving Google Keyboard Query Suggestions,” arXiv preprint arXiv:1812.02903, 2018. [18] T. Nishio and R. Yonetani, “Client Selection for Federated Learning with Heterogeneous Resources in Mobile Edge,” in Proc. the IEEE International Conference on Communications (ICC), 2019.