Optimizing Federated Learning on Non-IID Data with Reinforcement Learning

Hao Wang¹, Zakhary Kaplan¹, Di Niu² and Baochun Li¹
¹University of Toronto, {haowang, bli}@ece.utoronto.ca, [email protected]
²University of Alberta, [email protected]

This work is supported by a research contract with Huawei Technologies Co. Ltd. and an NSERC Discovery Research Program.

Abstract—The widespread deployment of machine learning applications in ubiquitous environments has sparked interest in exploiting the vast amount of data stored on mobile devices. To preserve data privacy, Federated Learning has been proposed to learn a shared model by performing distributed training locally on participating devices and aggregating the local models into a global one. However, due to the limited network connectivity of mobile devices, it is not practical for federated learning to perform model updates and aggregation on all participating devices in parallel. Besides, data samples across all devices are usually not independent and identically distributed (IID), posing additional challenges to the convergence and speed of federated learning. In this paper, we propose FAVOR, an experience-driven control framework that intelligently chooses the client devices to participate in each round of federated learning to counterbalance the bias introduced by non-IID data and to speed up convergence. Through both empirical and mathematical analysis, we observe an implicit connection between the distribution of training data on a device and the model weights trained based on those data, which enables us to profile the data distribution on that device based on its uploaded model weights. We then propose a mechanism based on deep Q-learning that learns to select a subset of devices in each communication round to maximize a reward that encourages the increase of validation accuracy and penalizes the use of more communication rounds. With extensive experiments performed in PyTorch, we show that the number of communication rounds required in federated learning can be reduced by up to 49% on the MNIST dataset, 23% on FashionMNIST, and 42% on CIFAR-10, as compared to the Federated Averaging algorithm.

I. INTRODUCTION

As the primary computing resource for billions of users, mobile devices are constantly generating massive volumes of data, such as photos, voices, and keystrokes, which are of great value for training machine learning models. However, due to privacy concerns, collecting private data from mobile devices to the cloud for centralized model training is not always possible. To efficiently utilize end-user data, Federated Learning (FL) has emerged as a new paradigm of distributed machine learning that orchestrates model training across mobile devices [1]. With federated learning, locally trained models are communicated to a server for aggregation, without collecting any raw data from users. Federated learning has enabled joint model training over privacy-sensitive data in a wide range of applications, including natural language processing, computer vision, and speech recognition.

However, unlike in a server-based environment, distributed machine learning on mobile devices faces a few fundamental challenges, such as the limited connectivity of wireless networks, the unstable availability of mobile devices, and the non-IID distributions of local datasets, which are hard to characterize statistically [1]–[4]. For these reasons, it is not as practical to perform machine learning on all participating devices simultaneously in a synchronous or semi-synchronous manner as it is in a parameter-server cluster. Therefore, as a common practice in federated learning, only a subset of devices is randomly selected to participate in each round of model training, which avoids long-tailed waiting times due to unstable network conditions and straggler devices [1, 5].

Since the completion time of federated learning is dominated by the communication time, Konečný et al. [6, 7] proposed structured and sketched updates to decrease the time it takes to complete a communication round. McMahan et al. [1] presented the Federated Averaging (FEDAVG) algorithm, which aims to reduce the number of communication rounds required by averaging model weights from client devices instead of applying traditional gradient descent updates. FEDAVG has also been deployed in a large-scale system [4] to test its effectiveness.

However, existing federated learning methods have not solved the statistical challenges posed by heterogeneous local datasets. Since different users have different device usage patterns, the data samples and labels located on any individual device may follow a different distribution, which cannot represent the global data distribution. It has been recently pointed out that the performance of federated learning, especially FEDAVG, may significantly degrade in the presence of non-IID data, in terms of both the model accuracy and the number of communication rounds required for convergence [8]–[10]. In particular, FEDAVG randomly selects a subset of devices in each round and averages their local model weights to update the global model. The randomly selected local datasets may not reflect the true global data distribution, which inevitably introduces bias into the global model updates. Furthermore, models locally trained on non-IID data can be significantly different from each other, and aggregating these divergent models can slow down convergence and substantially reduce the model accuracy [8].
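To make the non-IID setting above concrete, the sketch below shows one common way to simulate label-skewed client datasets in experiments: sort the training samples by label and deal contiguous shards to clients, so that each client observes only a few classes. The specific scheme (two label shards per client) and the function name `label_skewed_partition` are illustrative assumptions, not a procedure prescribed by this paper.

```python
import numpy as np

def label_skewed_partition(labels, num_clients, shards_per_client=2, seed=0):
    """Assign sample indices to clients so each client holds only a few classes.

    Sorting by label and dealing out contiguous shards is a simple way to
    simulate the non-IID (label-skewed) local datasets discussed above.
    """
    rng = np.random.default_rng(seed)
    by_label = np.argsort(labels)                    # group sample indices by label
    shards = np.array_split(by_label, num_clients * shards_per_client)
    shard_ids = rng.permutation(len(shards))
    clients = {}
    for c in range(num_clients):
        picked = shard_ids[c * shards_per_client:(c + 1) * shards_per_client]
        clients[c] = np.concatenate([shards[s] for s in picked])
    return clients  # client id -> array of sample indices

# Example: 100 clients over a 10-class dataset, roughly two classes per client.
# parts = label_skewed_partition(train_set.targets.numpy(), num_clients=100)
```

Averaging models trained on partitions like this is exactly the setting in which FEDAVG's random device selection introduces the bias described above.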
In this paper, we present the design and implementation of FAVOR, a control framework that aims to improve the performance of federated learning through intelligent device selection. Based on reinforcement learning, FAVOR aims to accelerate and stabilize the federated learning process by learning to actively select, in each communication round, the subset of devices that best counterbalances the bias introduced by non-IID data.

Since privacy concerns have ruled out any possibility of accessing the raw data on each device, it is impossible to profile the data distribution on each device directly. However, we observe that there is an implicit connection between the distribution of the training samples on a device and the model weights trained based on those samples. Therefore, our main intuition is to use the local model weights and the shared global model as states to judiciously select devices that may contribute to global model improvement.

We propose a reinforcement learning agent based on a Deep Q-Network (DQN), which is trained through a double DQN for increased efficiency and robustness. We carefully design the reward signal in order to aggressively improve the global model accuracy per round as well as to reduce the total number of communication rounds required. We also propose a practical scheme that compresses model weights to reduce the high dimensionality of the state space, which matters especially for large models such as deep neural networks.
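As a rough illustration of this idea, the sketch below shows a device-selection agent of the kind described above: a small Q-network scores candidate devices from a (compressed) state vector, and the K highest-scoring devices are chosen, with occasional random exploration. The network architecture, the epsilon-greedy rule, the `reward` shaping, and all names here are assumptions made for illustration; the paper's actual state construction, reward definition, and double-DQN training procedure are specified later in the paper and may differ.

```python
import torch
import torch.nn as nn

class QNet(nn.Module):
    """Hypothetical Q-network: maps a (compressed) state vector to one
    Q-value per candidate device."""
    def __init__(self, state_dim, num_devices, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, num_devices),
        )

    def forward(self, state):
        return self.net(state)

def select_devices(q_net, state, k, epsilon=0.1):
    """Epsilon-greedy selection of the K devices with the highest Q-values."""
    num_devices = q_net.net[-1].out_features
    if torch.rand(1).item() < epsilon:               # explore: random subset
        return torch.randperm(num_devices)[:k].tolist()
    with torch.no_grad():                            # exploit: top-K Q-values
        q_values = q_net(state)
    return torch.topk(q_values, k).indices.tolist()

def reward(acc_now, acc_prev, round_cost=0.01):
    """Illustrative reward: favor per-round validation-accuracy gains and
    charge a small cost per communication round (an assumed form, not the
    paper's exact definition)."""
    return (acc_now - acc_prev) - round_cost
```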
We have implemented FAVOR in a federated learning simulator developed from scratch using PyTorch, and evaluated it under a variety of federated learning tasks. Our experimental results on the MNIST, FashionMNIST, and CIFAR-10 datasets show that FAVOR can reduce the required number of communication rounds in federated learning by up to 49% on MNIST, by up to 23% on FashionMNIST, and by up to 42% on CIFAR-10, as compared to the FEDAVG algorithm.

II. BACKGROUND AND MOTIVATION

In this section, we briefly introduce how federated learning works in general and show how non-IID data poses challenges to existing federated learning algorithms. We then demonstrate how to properly select client devices in each round to improve the performance of federated learning on non-IID data.

A. Federated Learning

Federated learning (FL) trains a shared global model by iteratively aggregating model updates from multiple client devices, which may have slow and unstable network connections. Initially, eligible client devices check in with a remote server. The server then proceeds with federated learning synchronously in rounds: in each round, it randomly selects a subset of available client devices to participate in training.

Consider a C-class classification task, and let f(x, w) denote the prediction function, taking values in S = {z | \sum_{i=1}^{C} z_i = 1, z_i > 0, \forall i \in [C]}. That is, the vector-valued function f yields a probability vector z for each sample x, where f_i predicts the probability that the sample belongs to the i-th class. Let the vector w denote the model weights. For classification, the commonly used training loss is cross entropy, defined as

\ell(w) := \mathbb{E}_{x,y \sim p}\Big[ -\sum_{i=1}^{C} \mathbb{1}_{y=i} \log f_i(x, w) \Big] = -\sum_{i=1}^{C} p(y=i)\, \mathbb{E}_{x|y=i}\big[ \log f_i(x, w) \big].

The learning problem is to solve the following optimization problem:

\operatorname*{minimize}_{w} \; -\sum_{i=1}^{C} p(y=i)\, \mathbb{E}_{x|y=i}\big[ \log f_i(x, w) \big].    (1)

In federated learning, suppose there are N client devices in total. The k-th device holds m^{(k)} data samples drawn from the data distribution p^{(k)}, which is a joint distribution of the samples {x, y} on this device. In each round t, K devices are selected (at random in the original FL), each of which downloads the current global model weights w_{t-1} from the server and performs the following stochastic gradient descent (SGD) training locally:

w_t^{(k)} = w_{t-1} - \eta \nabla \ell(w_{t-1}) = w_{t-1} + \eta \sum_{i=1}^{C} p^{(k)}(y=i)\, \nabla_w \mathbb{E}_{x|y=i}\big[ \log f_i(x, w_{t-1}) \big],    (2)

where \eta is the learning rate. Note that in this notation with expectations, w_t^{(k)} can be the result of one or multiple epochs of local SGD training [1].

Once w_t^{(k)} is obtained, each participating device k reports a model weight difference \Delta_t^{(k)} to the FL server [1], defined as

\Delta_t^{(k)} := w_t^{(k)} - w_{t-1}.

Devices can also upload their local models w_t^{(k)} directly, but transferring \Delta_t^{(k)} is more amenable to compression and communication reduction [4]. After the FL server has collected updates from all K participating devices in round t, it performs the federated averaging (FEDAVG) algorithm to update the global model:

w_t = w_{t-1} + \sum_{k=1}^{K} \frac{m^{(k)}}{\sum_{k'=1}^{K} m^{(k')}} \, \Delta_t^{(k)}.
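The client half of this procedure translates directly into code. The following is a minimal PyTorch sketch of the local step in Eq. (2) and of computing the weight difference \Delta_t^{(k)}; the helper name `local_update` and its default hyperparameters are our own assumptions, and a plain SGD loop over the device's data loader stands in for whatever local training schedule is actually used.

```python
import copy
import torch
import torch.nn.functional as F

def local_update(global_model, loader, lr=0.01, epochs=1):
    """Client-side step of Eq. (2): start from the global weights w_{t-1},
    run local SGD on this device's data, and return the weight difference
    Delta_t^(k) = w_t^(k) - w_{t-1} together with the sample count m^(k)."""
    model = copy.deepcopy(global_model)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):                        # one or more local epochs
        for x, y in loader:
            opt.zero_grad()
            loss = F.cross_entropy(model(x), y)    # the cross-entropy loss l(w)
            loss.backward()
            opt.step()
    w_old, w_new = global_model.state_dict(), model.state_dict()
    delta = {name: w_new[name] - w_old[name] for name in w_new}
    return delta, len(loader.dataset)
```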

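The server-side counterpart below applies the sample-count-weighted average of the reported differences to the current global weights, matching the FEDAVG aggregation described above. It assumes the `(delta, num_samples)` pairs produced by the `local_update` sketch; `fedavg_aggregate` is an illustrative name, not an API of the paper's simulator.

```python
def fedavg_aggregate(global_model, updates):
    """Server-side FEDAVG step: add the sample-count-weighted average of the
    reported differences Delta_t^(k) to the global weights w_{t-1}.

    `updates` is a list of (delta, num_samples) pairs from the clients."""
    total = sum(m for _, m in updates)
    state = global_model.state_dict()
    for name in state:
        weighted = sum(delta[name] * (m / total) for delta, m in updates)
        state[name] = state[name] + weighted
    global_model.load_state_dict(state)
    return global_model

# One simulated round: the server picks K clients (randomly in FEDAVG, via a
# learned selection policy in FAVOR), collects their updates, and averages.
# updates = [local_update(global_model, loaders[k]) for k in selected_clients]
# global_model = fedavg_aggregate(global_model, updates)
```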