
Learning to Stop While Learning to Predict

Xinshi Chen 1, Hanjun Dai 2, Yu Li 3, Xin Gao 3, Le Song 1 4

1 Georgia Institute of Technology, USA. 2 Google Research, USA. 3 King Abdullah University of Science and Technology, Saudi Arabia. 4 Ant Financial, China. Correspondence to: Xinshi Chen <[email protected]>, Le Song <[email protected]>. Proceedings of the 37th International Conference on Machine Learning, Vienna, Austria, PMLR 108, 2020. Copyright 2020 by the author(s).

Abstract

There is a recent surge of interest in designing deep architectures based on the update steps in traditional algorithms, or in learning neural networks to improve and replace traditional algorithms. While traditional algorithms have certain stopping criteria for outputting results at different iterations, many algorithm-inspired deep models are restricted to a "fixed depth" for all inputs. Similar to algorithms, the optimal depth of a deep architecture may differ across input instances, either to avoid "over-thinking" or because we want to compute less for operations that have already converged. In this paper, we tackle this varying-depth problem with a steerable architecture, where a feed-forward deep model and a variational stopping policy are learned together to sequentially determine the optimal number of layers for each input instance. Training such an architecture is very challenging. We provide a variational Bayes perspective and design a novel and effective training procedure that decomposes the task into an oracle model learning stage and an imitation stage. Experimentally, we show that the learned deep model, along with the stopping policy, improves performance on a diverse set of tasks, including learning sparse recovery, few-shot meta learning, and computer vision tasks.

Figure 1. Motivation for learning to stop. (a) Learning-based algorithm design: a fixed-depth learned algorithm versus a dynamic-depth traditional algorithm that repeats an update step until a hand-designed stopping criterion is satisfied. (b) Task-imbalanced meta learning, where different tasks (Task 1, Task 2) have different amounts of data.

1. Introduction

Recently, researchers have become increasingly interested in the connections between deep learning models and traditional algorithms: deep learning models are viewed as parameterized algorithms that operate on each input instance iteratively, and traditional algorithms are used as templates for designing deep learning architectures. While an important concept in traditional algorithms is the stopping criterion for outputting the result, which can be either a convergence condition or an early stopping rule, such stopping criteria have been more or less ignored in algorithm-inspired deep learning models: a "fixed-depth" deep model is used to operate on all problem instances (Fig. 1(a)). Intuitively, for deep learning models, the optimal depth (or the optimal number of steps to operate on an input) can also be different for different input instances, either because we want to compute less for operations that have already converged, or because we want to generalize better by avoiding "over-thinking". Such motivation aligns well with both the cognitive science literature (Jones et al., 2009) and the examples below:

• In learning to optimize (Andrychowicz et al., 2016; Li & Malik, 2016), neural networks are used as the optimizer to minimize some loss function. Depending on the initialization and the objective function, an optimizer should converge in a different number of steps;

• In learning to solve statistical inverse problems such as compressed sensing (Chen et al., 2018; Liu et al., 2019), inverse covariance estimation (Shrivastava et al., 2020), and image denoising (Zhang et al., 2019), deep models are learned to directly predict the recovery results. In traditional algorithms, problem-dependent early stopping rules are widely used to achieve regularization via a bias-variance trade-off. Deep learning models for solving such problems may also achieve better recovery accuracy by allowing instance-specific numbers of computation steps;

• In meta learning, MAML (Finn et al., 2017) uses an unrolled and parametrized algorithm to adapt a common parameter to a new task. However, depending on the similarity of the new task to the old tasks, or, in a more realistic task-imbalanced setting where different tasks have different numbers of data points (Fig. 1(b)), a task-specific number of adaptation steps is more favorable to avoid under- or over-adaptation.
To address the varying-depth problem, we propose to learn a steerable architecture, in which a shared feed-forward model for normal prediction and an additional stopping policy are learned together to sequentially determine the optimal number of layers for each input instance. In our framework, the model consists of (see Fig. 2):

• A feed-forward or recurrent mapping F_θ, which transforms the input x to generate a path of features (or states) x_1, ..., x_T; and

• A stopping policy π_φ : (x, x_t) ↦ π_t ∈ [0, 1], which sequentially observes the states and then determines the probability of stopping the computation of F_θ at layer t.

Figure 2. Two-component model: learning to predict (blue) while learning to stop (green).

These two components allow us to sequentially predict the next targeted state while at the same time determining when to stop. In this paper, we propose a single objective function for learning both θ and φ, and we interpret it from the perspective of variational Bayes, where the stopping time t is viewed as a latent variable conditioned on the input x. With this interpretation, learning θ corresponds to maximizing the marginal likelihood, and learning φ corresponds to the inference step for the latent variable, where a variational distribution q_φ(t) is optimized to approximate the posterior. A natural algorithm for solving this problem could be the Expectation-Maximization (EM) algorithm, which, however, can be very hard to train and inefficient.
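To make the two components concrete, the following is a minimal PyTorch-style sketch of one possible instantiation. The class name SteerableModel, the MLP update blocks, the sigmoid stopping head, and the product-form construction of q_φ(t) from the sequential probabilities π_t are illustrative assumptions rather than the exact architecture and parameterization used here.

```python
import torch
import torch.nn as nn


class SteerableModel(nn.Module):
    """Minimal sketch of the two-component model: F_theta (predict) + pi_phi (stop)."""

    def __init__(self, dim, T):
        super().__init__()
        self.T = T
        # F_theta: one update block f_{theta_t} per layer, producing x_1, ..., x_T.
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(T)]
        )
        # pi_phi: observes (x, x_t) and outputs a stopping probability in [0, 1].
        self.stop_policy = nn.Sequential(
            nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )

    def forward(self, x):
        states, stop_probs = [], []
        x_t = x  # x_0 is determined by the input x
        for f_t in self.layers:
            x_t = f_t(x_t)  # x_t = f_{theta_t}(x_{t-1})
            pi_t = self.stop_policy(torch.cat([x, x_t], dim=-1))
            states.append(x_t)
            stop_probs.append(pi_t)
        # One natural way to induce a distribution q_phi(t) over stopping times from
        # the sequential probabilities pi_1, ..., pi_T: stop at layer t if the policy
        # fires at t and has not fired before; any leftover mass goes to t = T.
        pi = torch.cat(stop_probs, dim=-1)                    # shape (batch, T)
        survive = torch.cumprod(1.0 - pi, dim=-1)             # prod_{s<=t} (1 - pi_s)
        prefix = torch.cat([torch.ones_like(pi[:, :1]), survive[:, :-1]], dim=-1)
        q = pi * prefix                                       # q(t) = pi_t * prod_{s<t} (1 - pi_s)
        q = torch.cat([q[:, :-1], q[:, -1:] + survive[:, -1:]], dim=-1)
        return states, q                                      # q sums to 1 along dim=-1
```

For instance, SteerableModel(dim=32, T=20)(torch.randn(8, 32)) would return the path of 20 states together with an 8 × 20 matrix of stopping probabilities that sums to one along each row.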
How can we learn θ and φ effectively and efficiently? We propose a principled and effective training procedure in which we decompose the task into an oracle model learning stage and an imitation learning stage (Fig. 3). More specifically:

• During the oracle model learning stage, we utilize a closed-form oracle stopping distribution q*(·|θ), which can leverage label information not available at testing time.

• In the imitation learning stage, we use a sequential policy π_φ to mimic the behavior of the oracle policy obtained in the first stage. The sequential policy does not have access to the label, so it can be used during the testing phase.

Figure 3. Two-stage training framework: in Stage I, F_θ is learned together with an oracle stopping distribution q*; in Stage II, the stopping policy is trained to imitate the optimal q*.

This procedure provides us with a very good initial predictive model and stopping policy. We can either directly use these learned models, or plug them back into the variational EM framework and reiterate to further optimize both together.

Our proposed learning-to-stop method is a generic framework that can be applied to a diverse range of applications. To summarize, our contributions in this paper include:

1. a variational Bayes perspective for understanding the proposed model, in which the predictive model and the stopping policy are learned together;
2. a principled and efficient algorithm for jointly learning the predictive model and the stopping policy, and the relation of this algorithm to reinforcement learning; and
3. promising experiments on various tasks, including learning to solve sparse recovery problems, task-imbalanced few-shot meta learning, and computer vision tasks, where we demonstrate the effectiveness of our method in terms of both prediction accuracy and inference efficiency.

2. Related Works

Unrolled algorithms. A line of recent works unfolds and truncates iterative algorithms to design neural architectures. These algorithm-based deep models can be used to automatically learn a better algorithm from data. This idea has been demonstrated on different problems, including sparse signal recovery (Gregor & LeCun, 2010; Sun et al., 2016; Borgerding et al., 2017; Metzler et al., 2017; Zhang & Ghanem, 2018; Chen et al., 2018; Liu et al., 2019), sparse inverse covariance estimation (Shrivastava et al., 2020), sequential Bayesian inference (Chen et al., 2019), parameter learning in graphical models (Domke, 2011), non-negative matrix factorization (Yakar et al., 2013), etc. Unrolled-algorithm-based deep modules have also been used for structured prediction (Belanger et al., 2017; Ingraham et al., 2019; Chen et al., 2020). Before the training phase, all these works need to assign a fixed number of iterations that is used for every input instance regardless of its difficulty level. Our proposed method is orthogonal and complementary to all these works: it takes the variety of the input instances into account via an adaptive stopping time.

Meta learning. Optimization-based meta learning techniques are widely applied for solving challenging few-shot learning problems (Ravi & Larochelle, 2017; Finn et al., 2017; Li et al., 2017). Several recent advances proposed task-adaptive meta-learning models which incorporate task-specific parameters (Qiao et al., 2018; Lee & Choi, 2018; Na et al., 2020) or task-dependent metric scaling (Oreshkin et al., 2018).

3.1. Steerable Model

The predictive model F_θ is a typical T-layer deep model that generates a path of embeddings (x_1, ..., x_T) through

Predictive model:  x_t = f_{θ_t}(x_{t-1}),  for t = 1, ..., T,   (1)

where the initial state x_0 is determined by the input x. We denote the model by F_θ = {f_{θ_1}, ..., f_{θ_T}}, where θ ∈ Θ are the parameters. Standard supervised learning methods learn θ by optimizing an objective estimated on the final state x_T. In our model, the operations in Eq. 1 can be stopped earlier, and for different input instances x, the stopping time t can be different.
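As a small illustration of how the unrolling in Eq. 1 can be cut short per instance at test time, the sketch below reuses the hypothetical SteerableModel from Section 1 and stops the first time the stopping probability crosses a threshold; the greedy 0.5 rule is only one simple choice and is not claimed to be the inference procedure used in this paper.

```python
import torch


@torch.no_grad()
def adaptive_forward(model, x, threshold=0.5):
    """Run x_t = f_{theta_t}(x_{t-1}) layer by layer, stopping early per instance.

    `model` is assumed to be an instance of the SteerableModel sketched earlier.
    A greedy rule is used: stop as soon as pi_t >= threshold (single-instance input).
    Returns the state at the instance-specific stopping layer and its index t.
    """
    x_t = x
    for t, f_t in enumerate(model.layers, start=1):
        x_t = f_t(x_t)                                            # Eq. (1)
        pi_t = model.stop_policy(torch.cat([x, x_t], dim=-1))
        if pi_t.item() >= threshold:                              # assumes batch size 1
            return x_t, t
    return x_t, model.T                                           # never fired: use x_T


# Hypothetical usage: two different inputs may stop at different layers.
# model = SteerableModel(dim=32, T=20)
# x_hat, t_stop = adaptive_forward(model, torch.randn(1, 32))
```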