A Sample Complexity Separation between Non-Convex and Convex Meta-Learning
Nikunj Saunshi¹  Yi Zhang¹  Mikhail Khodak²  Sanjeev Arora¹ ³
¹Princeton University, Princeton, New Jersey, USA  ²Carnegie Mellon University, Pittsburgh, Pennsylvania, USA  ³Institute for Advanced Study, Princeton, New Jersey, USA. Correspondence to: Nikunj Saunshi.

Abstract

One popular trend in meta-learning is to learn from many training tasks a common initialization that a gradient-based method can use to solve a new task with few samples. The theory of meta-learning is still in its early stages, with several recent learning-theoretic analyses of methods such as Reptile (Nichol et al., 2018) being for convex models. This work shows that convex-case analysis might be insufficient to understand the success of meta-learning, and that even for non-convex models it is important to look inside the optimization black-box, specifically at properties of the optimization trajectory. We construct a simple meta-learning instance that captures the problem of one-dimensional subspace learning. For the convex formulation of linear regression on this instance, we show that the new task sample complexity of any initialization-based meta-learning algorithm is $\Omega(d)$, where $d$ is the input dimension. In contrast, for the non-convex formulation of a two-layer linear network on the same instance, we show that both Reptile and multi-task representation learning can have new task sample complexity of $\mathcal{O}(1)$, demonstrating a separation from convex meta-learning. Crucially, analyses of the training dynamics of these methods reveal that they can meta-learn the correct subspace onto which the data should be projected.

1. Introduction

We consider the problem of meta-learning, or learning-to-learn (Thrun & Pratt, 1998), in which the goal is to use the data from numerous training tasks to reduce the sample complexity of an unseen but related test task. Although there is a long history of successful methods in meta-learning and the related areas of multi-task and lifelong learning (Evgeniou & Pontil, 2004; Ruvolo & Eaton, 2013), recent approaches have been developed with the diversity and scale of modern applications in mind. This has given rise to simple, model-agnostic methods that focus on learning a good initialization for some gradient-based method such as stochastic gradient descent (SGD), to be run on samples from a new task (Finn et al., 2017; Nichol et al., 2018). These methods have found widespread applications in a variety of areas such as computer vision (Nichol et al., 2018), reinforcement learning (Finn et al., 2017), and federated learning (McMahan et al., 2017).

Inspired by their popularity, several recent learning-theoretic analyses of meta-learning have followed suit, eschewing customization to specific hypothesis classes such as halfspaces (Maurer & Pontil, 2013; Balcan et al., 2015) and instead favoring the convex-case study of gradient-based algorithms that could potentially be applied to deep neural networks (Denevi et al., 2019; Khodak et al., 2019). This has yielded results showing that meta-learning an initialization by using methods similar to Reptile (Nichol et al., 2018) for convex models leads to a reduction in sample complexity of unseen tasks. These benefits are shown using natural notions of task-similarity like the average distance between the risk minimizers of tasks drawn from an underlying meta-distribution. A good initialization in these models is one that is close to the population risk minimizers for tasks in this meta-distribution.

In this paper we argue that, even in some simple settings, such convex-case analyses are insufficient to understand the success of initialization-based meta-learning algorithms. For this purpose, we pose a simple instance for meta-learning linear regressors that share a one-dimensional subspace, for which we prove a sample complexity separation between convex and non-convex methods. In the process, we provide a first theoretical demonstration of how initialization-based meta-learning methods can learn good representations of the data. Our contributions are summarized below.
• We show, in the convex formulation of linear regression on this instance, a new task sample complexity lower bound of $\Omega(d)$ for any initialization-based meta-learning algorithm. This suggests that no amount of meta-training data can yield an initialization that can be used by common gradient-based within-task algorithms to solve a new task with fewer samples than if no meta-learning had been done; thus initialization-based meta-learning in the convex formulation fails to learn the underlying task-similarity. The lower bound also holds for more general meta-learning instances that have the property that the average prediction for an input across tasks is independent of the input, as is true in many standard meta-learning benchmarks.

• We show for the same instance that formulating the model as a two-layer linear network – an over-parameterization of the same hypothesis class – allows a Reptile-like procedure to use training tasks from this meta-learning instance and find an initialization for gradient descent that will have $\mathcal{O}(1)$ sample complexity on a new task. To the best of our knowledge, this is the first sample complexity analysis of initialization-based meta-learning algorithms in the non-convex setting. Additionally, this demonstrates that initialization-based methods can capture the strong property of representation learning.

• Central to our proof is a trajectory-based analysis of the properties of the solutions found by specific procedures like Reptile or gradient descent on a representation learning objective. For the latter, we show that looking at the trajectory is crucial, as not all minimizers can learn the subspace structure.

• Finally, we revisit existing upper bounds for the convex case. We show that our lower bound does not contradict these upper bounds, since their task-similarity measure of average parameter distance is large in our case. We complement this observation by proving that the existing bounds are tight, in some sense, and going beyond them will require additional structural assumptions.

Paper organization: We discuss related work in Section 2. Section 3 sets up notation for the rest of the paper, formalizes initialization-based meta-learning methods and defines the subspace meta-learning instance that we are interested in. The lower bound for linear regression is stated in Section 4, while the corresponding upper bounds for non-convex meta-learning with a two-layer linear network are provided in Section 5. While all proofs are provided in the appendix, we give a sketch of the proofs for the upper bounds in Section 6 to highlight the key steps in the trajectory-based analysis and discuss why such an analysis is important. A discussion about tightness of existing convex-case upper bounds can be found in Section 7.
2. Related Work

There is a rich history of theoretical analysis of learning-to-learn (Baxter, 2000; Maurer, 2005; Maurer et al., 2016). Our focus is on a well-studied setting in which tasks such as halfspace learning share a common low-dimensional subspace, with the goal of obtaining sample complexity depending on this sparse structure rather than on the ambient dimension (Maurer, 2009; Maurer & Pontil, 2013; Balcan et al., 2015; Denevi et al., 2018; Bullins et al., 2019; Khodak et al., 2019). While these works derive specialized algorithms, we instead focus on learning an initialization for gradient-based methods such as SGD or a few steps of gradient descent (Finn et al., 2017; Nichol et al., 2018). Some of these methods have recently been studied in the convex setting (Denevi et al., 2019; Khodak et al., 2019; Zhou et al., 2019). Our results show that such convex-case analyses cannot hope to show adaptation to an underlying low-dimensional subspace leading to dimension-independent sample complexity bounds. On the other hand, we show that their guarantees using distance-from-initialization are almost tight for the meta-learning of convex Lipschitz functions.

To get around the limitations of convexity for the problem of meta-learning a shared subspace, we instead study non-convex models. While the optimization properties of gradient-based meta-learning algorithms have been recently studied in the non-convex setting (Fallah et al., 2019; Rajeswaran et al., 2019; Zhou et al., 2019), these results only provide stationary-point convergence guarantees and do not show a reduction in sample complexity, the primary goal of meta-learning. Our theory is more closely related to recent empirical work that tries to understand various inherently non-convex properties of learning-to-learn. Most notably, Arnold et al. (2019) hypothesize and show some experimental evidence that the success of gradient-based meta-learning requires non-convexity, a view theoretically supported by our work. Meanwhile, Raghu et al. (2019) demonstrate that the success of the popular MAML algorithm (Finn et al., 2017) is likely due to its ability to learn good data-representations rather than adapt quickly. Our upper bound, in fact, supports these empirical findings by showing that Reptile can indeed learn a good representation at the penultimate layer, and that learning just a final linear layer for a new task will reduce its sample complexity.

Our results draw upon work, motivated by understanding deep learning, that analyzes trajectories and implicit regularization in deep linear neural networks (Saxe et al., 2014; Gunasekar et al., 2018; Saxe et al., 2019; Gidel et al., 2019). The analysis of solutions found by gradient flow in deep linear networks by Saxe et al. (2014) and Gidel et al. (2019) forms a core component of our analysis. In this vein, Lampinen & Ganguli (2019) recently studied the dynamics of deep linear networks in the context of transfer learning and show that jointly learning linear representations using two tasks will yield smaller error on each one than individual task learning. However, their guarantees are not for an unseen task drawn from a distribution, but only for two given tasks, and crucially not for gradient-based meta-learning methods.

3. Meta-Learning Setup

3.1. Notations

Let $[N]$ denote the set $\{1, \dots, N\}$. We use $x$ for vectors, $M$ for matrices, $I_d$ for the $d$-dimensional identity matrix and $0_d$ for the all-zero vector in $d$ dimensions. $\|\cdot\|$ is used to denote the $\ell_2$ norm. For a function $\ell : X \times Y \rightarrow Z$, we use $\ell(x, \cdot) : Y \rightarrow Z$ to denote a function of the second argument when the first argument is set to $x$. For a finite set $S$, $x \sim S$ denotes sampling uniformly from $S$. We also need the ReLU function $[x]_+ = x\,\mathbb{1}\{x \ge 0\}$. For a sequence $\{a_1, \dots, a_T\}$, we use $a_{i:j}$ for $j \ge i$ to denote the set $\{a_i, \dots, a_j\}$.
3.2. Task distribution and excess risk

We are interested in regression tasks of the following form
$$\ell_\rho(\theta) := \mathop{\mathbb{E}}_{(x,y)\sim\rho}\,(f(x,\theta) - y)^2 \qquad (1)$$
where we abuse notation and use $\rho$ to denote a task as well as its associated data distribution. The input $x$ is a vector in $\mathbb{R}^d$ and $y$ is a real-valued scalar. The function $f : \mathbb{R}^d \times \Theta \rightarrow \mathbb{R}$ is a regressor of choice, e.g. a linear function or a deep neural network, that is parametrized by $\theta \in \Theta$. Often one only has access to samples $S = \{(x_i, y_i)\}_{i=1}^{n}$ from the unknown distribution $\rho$, and the empirical risk is defined as
$$\ell_S(\theta) = \mathop{\mathbb{E}}_{(x,y)\sim S}\,(f(x,\theta) - y)^2 \qquad (2)$$

While various formalizations for meta-learning exist, we present one that is most convenient for the presentation of this work. In our meta-learning setting, we assume that there is an underlying unknown distribution $\mu$ over tasks. Given access to $T$ training tasks $\rho_1, \dots, \rho_T$ sampled from $\mu$, the goal of a meta-learner Meta is to learn some underlying structure that relates the tasks in $\mu$ and output a within-task algorithm $\mathrm{Alg} = \mathrm{Meta}(\rho_{1:T})$ that can be used to solve a new task sampled from $\mu$. To solve a new task $\rho \sim \mu$ by using a training set $S$ from $\rho$, the meta-learned algorithm Alg outputs parameters $\mathrm{Alg}(S) \in \Theta$. The average risk of an algorithm that uses $n$ samples from a new task is
$$\mathcal{L}_n(\mathrm{Alg}, \mu) = \mathop{\mathbb{E}}_{\rho\sim\mu}\, \mathop{\mathbb{E}}_{S\sim\rho^n}\, \ell_\rho(\mathrm{Alg}(S))$$
We define the excess risk of Alg as $\mathcal{E}_n(\mathrm{Alg}, \mu) = \mathcal{L}_n(\mathrm{Alg}, \mu) - \mathcal{L}^*(\mu)$, where $\mathcal{L}^*(\mu) = \mathop{\mathbb{E}}_{\rho\sim\mu} \inf_{\theta\in\Theta} \ell_\rho(\theta)$ is the minimum achievable risk by the class $\Theta$ with complete knowledge of the distribution $\mu$.

3.3. Initialization-based meta-learning

We focus on a popular approach in meta-learning that uses training tasks to learn an initialization of the model parameters. This initialization is fed into a pre-specified gradient-based algorithm that updates model parameters starting from this initialization by using samples from a new task. We refer to these methods as initialization-based meta-learning methods; they are restricted to return within-task algorithms of the form $\mathrm{Alg}(\cdot) = \text{GD-Alg}(\cdot\,; \theta_{\text{init}})$, where GD-Alg runs some gradient-based algorithm starting from the initialization $\theta_{\text{init}} \in \Theta$ on an objective function that depends on the input training set $S$. For example, we can denote the algorithm of gradient descent as $\mathrm{GD}(S; \theta_{\text{init}})$, which runs gradient descent to convergence on the empirical risk $\ell_S$ starting from the initialization $\theta_{\text{init}}$. The definitions of the various initialization-based meta-learning and within-task algorithms that we analyze are in Sections 4.1 and 5.1. In the subsequent sections, we will concretely define the distribution of tasks $\mu$ and the meta-learning algorithms we are interested in.
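To make these definitions concrete, the following is a minimal sketch (not code from the paper) of a within-task algorithm of the form $\mathrm{Alg}(\cdot) = \text{GD-Alg}(\cdot\,; \theta_{\text{init}})$ for a linear regressor, together with a Monte Carlo estimate of the average risk $\mathcal{L}_n(\mathrm{Alg}, \mu)$. All function names (`empirical_risk`, `gd_from_init`, `average_risk`) and hyperparameter choices are hypothetical illustrations of the abstractions above.

```python
import numpy as np

def empirical_risk(w, X, y):
    """Squared-error empirical risk ell_S(w) for a linear regressor f(x, w) = w^T x."""
    return np.mean((X @ w - y) ** 2)

def gd_from_init(X, y, w_init, lr=0.01, steps=500):
    """GD-Alg(S; theta_init): gradient descent on ell_S starting from w_init."""
    w = w_init.copy()
    n = len(y)
    for _ in range(steps):
        grad = (2.0 / n) * X.T @ (X @ w - y)   # gradient of mean((Xw - y)^2)
        w = w - lr * grad
    return w

def average_risk(alg, sample_task, sample_data, population_risk, n, trials=200):
    """Monte Carlo estimate of L_n(Alg, mu) = E_{rho ~ mu} E_{S ~ rho^n} ell_rho(Alg(S))."""
    risks = []
    for _ in range(trials):
        task = sample_task()            # draw a task rho ~ mu
        X, y = sample_data(task, n)     # draw a training set S ~ rho^n
        theta_hat = alg(X, y)           # run the within-task algorithm on S
        risks.append(population_risk(task, theta_hat))
    return float(np.mean(risks))
```

Subtracting an estimate of $\mathcal{L}^*(\mu)$ from the output of `average_risk` gives a Monte Carlo estimate of the excess risk $\mathcal{E}_n(\mathrm{Alg}, \mu)$.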
3.4. Meta-learning a subspace

For meta-learning to be meaningful, the tasks must share some common structure. Here we focus on a structure that assumes the existence of a low-dimensional representation of the data that suffices to solve all the tasks, specifically, a linear representation. To capture this idea, we construct a simple but instructive meta-learning instance.

We are interested in tasks $\rho_w$ for $w \in \mathbb{R}^d$, where the distribution is defined as follows
$$(x, y) \sim \rho_w : \; x \sim \mathcal{N}(0, I_d), \; y \sim \mathcal{N}(w^\top x, \sigma^2) \qquad (3)$$
The target $y$ for $x$ is a linear function of $x$ plus zero-mean Gaussian noise.¹ The meta-learning instance $\mu_{w^*}$ is defined as the uniform distribution over the two tasks $\rho_{w^*}$ and $\rho_{-w^*}$ for a fixed but unknown vector $w^* \in \mathbb{R}^d$. Note that for every point $x \in \mathbb{R}^d$, only the projection of $x$ onto the direction of $w^*$ is necessary to solve all tasks in $\mu_{w^*}$. Thus the hope is that a meta-learning algorithm picks up on this structure and learns to project data onto this subspace for sample efficiency on a new task. The average task risk and excess risk for an algorithm Alg can then be written as
$$\mathcal{L}_n(\mathrm{Alg}, \mu_{w^*}) = \mathop{\mathbb{E}}_{s\sim\{\pm 1\}}\, \mathop{\mathbb{E}}_{S\sim\rho_{sw^*}^n}\, \ell_{sw^*}(\mathrm{Alg}(S))$$
$$\mathcal{E}_n(\mathrm{Alg}, \mu_{w^*}) = \mathcal{L}_n(\mathrm{Alg}, \mu_{w^*}) - \mathcal{L}^*(\mu_{w^*}) \qquad (4)$$

In the subsequent sections, we describe the convex setting of linear regression and the equally expressive non-convex setting of a two-layer linear network regressor. Our main result shows that while no meta-learning algorithm can learn a meaningful initialization for a gradient-based within-task algorithm in the convex setting, standard meta-learning algorithms like Reptile on a two-layer linear network can in fact learn to project the data onto the one-dimensional subspace and thus reduce the sample complexity for a new task from $\Omega(d)$ to $\mathcal{O}(1)$.

¹ We can extend all results to $y = w^\top x + \xi$, where $\xi$ is independent of $x$ and just has mean $0$ and variance $\sigma^2$.
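Sampling from this instance is straightforward; the sketch below (hypothetical helper names, not code from the paper) draws a task $sw^*$ with $s$ uniform in $\{\pm 1\}$ and then draws $n$ samples according to Equation 3.

```python
import numpy as np

def sample_task(w_star, rng):
    """Draw a task from mu_{w*}: pick s uniformly from {+1, -1}; the task is rho_{s w*}."""
    s = rng.choice([-1.0, 1.0])
    return s * w_star

def sample_data(w, n, sigma, rng):
    """Draw n i.i.d. samples from rho_w: x ~ N(0, I_d), y ~ N(w^T x, sigma^2)."""
    d = w.shape[0]
    X = rng.standard_normal((n, d))
    y = X @ w + sigma * rng.standard_normal(n)
    return X, y

# Example usage: a training set of n = 10 samples from a new task drawn from mu_{w*}.
rng = np.random.default_rng(0)
d, sigma = 50, 1.0
w_star = rng.standard_normal(d)
w_star *= sigma / np.linalg.norm(w_star)   # make ||w*|| = sigma, as assumed in Theorem 4.1
w_task = sample_task(w_star, rng)
X, y = sample_data(w_task, n=10, sigma=sigma, rng=rng)
```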
4. Convex Meta-Learning Lower Bound

In this section, we use a regression function $f$ that is linear in $x$ to solve the meta-learning instance $\mu_{w^*}$. We have $\Theta = \mathbb{R}^d$, the parameters are $\theta = w$, $w \in \mathbb{R}^d$, and the regressor is $f(x, w) := w^\top x$. Using the definition of the distribution in Equation 3 and the fact that $\mathbb{E}[xx^\top] = I_d$, for $s \in \{\pm 1\}$ we get
$$\ell_{sw^*}(w) = \mathop{\mathbb{E}}_{(x,y)\sim\rho_{sw^*}} (w^\top x - y)^2 = \|w - sw^*\|^2 + \sigma^2 \qquad (5)$$
Thus we have $\mathcal{L}^*(\mu_{w^*}) = \mathop{\mathbb{E}}_{s\sim\{\pm 1\}} \inf_{w\in\mathbb{R}^d} \ell_{sw^*}(w) = \sigma^2$.

4.1. Within-task algorithms

As described in Section 3.3, we consider within-task algorithms that are based on gradient descent. A meta-learner is allowed to learn an initialization $w_0 \in \mathbb{R}^d$ that is used as a starting point to run a gradient-based algorithm on a new task. We will show lower bounds for the following algorithms.

• $\mathrm{GD}_{\text{step}}^{\eta,t_0}(S; w_0)$ - GD for $t_0$ steps:
Runs gradient descent with learning rate $\eta$ for $t_0$ steps on $\ell_S$ (defined in Equation 2). Starting from $w_0$, follow the dynamics below and return $w_{t_0}$.
$$w_{t+1} = w_t - \eta \nabla_w \ell_S(w_t)$$

• $\mathrm{GD}_{\text{reg}}^{\lambda}(S; w_0)$ - $\lambda$-regularized GD:
Runs gradient descent with vanishingly small learning rate (gradient flow) to convergence on $\ell_{S,\lambda}$,
$$\ell_{S,\lambda}(w) = \mathop{\mathbb{E}}_{(x,y)\sim S}\left[(w^\top x - y)^2\right] + \frac{\lambda}{2}\|w\|^2 \qquad (6)$$
Starting from $w_0$, follow the dynamics below, return $w_\infty$.
$$\frac{dw_t}{dt} = -\nabla_w \ell_{S,\lambda}(w_t)$$

In the next subsection we will provide lower bounds on the excess risk for all initialization-based meta-learning algorithms that return initializations for the above algorithms. Note that some of these algorithms have been used in prior work; most notably, $\mathrm{GD}_{\text{step}}^{\eta,t_0}$ is the base-learner used by MAML (Finn et al., 2017), so our convex-case lower bounds hold directly for any initialization it might learn.

4.2. Lower bounds

We use the definition of excess risk $\mathcal{E}_n$ from Equation 4 and formally define sample complexity for a meta-learned within-task algorithm below.

Definition 4.1 (Sample complexity). The minimum number of samples needed from a new task for a within-task algorithm Alg to have excess risk smaller than $\epsilon$ is
$$n_\epsilon(\mathrm{Alg}, \mu_{w^*}) = \min\{n \in \mathbb{N} : \mathcal{E}_n(\mathrm{Alg}, \mu_{w^*}) \le \epsilon\} \qquad (7)$$

We will proceed to show a lower bound for all meta-learning algorithms that return an initialization to be used by the algorithms $\mathrm{GD}_{\text{step}}^{\eta,t_0}$ and $\mathrm{GD}_{\text{reg}}^{\lambda}$ described in the previous subsection. We assume that $\|w^*\| = \sigma = r$ to make the noise of the same order as the signal and for simplicity of presentation. The lower bounds in more generality can be found in Appendix B.1.

Theorem 4.1. Suppose $\|w^*\| = \sigma = r$ and $\epsilon \in \left(0, \frac{r^2}{2}\right]$. For every initialization $w_0 \in \mathbb{R}^d$, the number of samples needed to have $\epsilon$ excess risk on a new task is
$$\min_{\lambda \ge 0}\; n_\epsilon\big(\mathrm{GD}_{\text{reg}}^{\lambda}(\cdot\,; w_0), \mu_{w^*}\big) = \Omega\!\left(\frac{dr^2}{\epsilon}\right)$$
$$\min_{\eta > 0,\, t_0 \in \mathbb{N}_+}\; n_\epsilon\big(\mathrm{GD}_{\text{step}}^{\eta,t_0}(\cdot\,; w_0), \mu_{w^*}\big) = \Omega\!\left(\frac{dr^2}{\epsilon}\right)$$

Remark 4.1. The lower bound is strong due to the following:
• The bound holds even if the meta-learner has seen infinitely many tasks sampled from $\mu$ and has access to the population loss for each task.
• Even regularization techniques like explicit $\ell_2$-regularization or early stopping cannot benefit from a meta-learned initialization.
• The condition $\epsilon \le r^2/2$ is not restrictive, since a trivial learner that outputs $0_d$ for all tasks has error exactly $r^2$.
• The lower bound also holds for more general meta-learning instances. For any distribution over the optimal regressors that satisfies $\mathop{\mathbb{E}}_{\rho_w\sim\mu}[w] = 0_d$, the lower bound will depend on the variance of the optimal regressors across tasks, i.e. $r^2 = \mathop{\mathbb{E}}_{\rho_w\sim\mu}[\|w\|^2]$. The zero-mean property is similar to many standard meta-learning benchmarks, where the average prediction across tasks for an input is independent of the input. This holds for sine regression (Finn et al., 2017) due to symmetry around zero, and for Omniglot (Lake et al., 2017) and Mini-ImageNet (Ravi & Larochelle, 2017) due to label-shuffling.

This demonstrates that the convex formulation does not do justice to the practical efficacy of such algorithms. We provide the proof of this result and even tighter lower bounds in the appendix. The proofs are based on finding a closed-form expression for the solutions found by $\mathrm{GD}_{\text{reg}}^{\lambda}$ and $\mathrm{GD}_{\text{step}}^{\eta,t_0}$ and showing that, in fact, no initialization has better excess risk than the trivial initialization of $0_d$.
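The within-task algorithms of Section 4.1 are easy to instantiate for this linear model. Below is a minimal NumPy sketch (hypothetical names; gradient flow approximated by many small gradient steps), not the authors' code: `gd_step` implements $\mathrm{GD}_{\text{step}}^{\eta,t_0}$ and `gd_reg` approximates $\mathrm{GD}_{\text{reg}}^{\lambda}$.

```python
import numpy as np

def gd_step(X, y, w0, lr, t0):
    """GD_step^{eta, t0}(S; w0): t0 gradient descent steps on ell_S(w) = mean((X w - y)^2)."""
    w = w0.copy()
    n = len(y)
    for _ in range(t0):
        w = w - lr * (2.0 / n) * X.T @ (X @ w - y)
    return w

def gd_reg(X, y, w0, lam, lr=1e-3, max_steps=200_000, tol=1e-10):
    """GD_reg^{lambda}(S; w0): gradient flow to convergence on
    ell_{S,lambda}(w) = mean((X w - y)^2) + (lambda / 2) * ||w||^2,
    approximated here by gradient descent with a small step size."""
    w = w0.copy()
    n = len(y)
    for _ in range(max_steps):
        grad = (2.0 / n) * X.T @ (X @ w - y) + lam * w
        w_next = w - lr * grad
        if np.linalg.norm(w_next - w) <= tol:
            return w_next
        w = w_next
    return w
```

Combined with the samplers from Section 3.4, these routines can be used to empirically probe the excess risk of any candidate initialization $w_0$.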
5. Non-Convex Meta-Learning Upper Bound

We now use a two-layer linear network as the regressor $f$. The parameters in this case are $\theta = (A, w)$, $A \in \mathbb{R}^{d\times d}$, $w \in \mathbb{R}^d$. The regressor $f$ is then defined as $f(x, (A, w)) := w^\top A x$. As before,
$$\ell_{sw^*}((A, w)) = \|A^\top w - sw^*\|^2 + \sigma^2 \qquad (8)$$
Again it is easy to see that $\mathcal{L}^*(\mu_{w^*}) = \sigma^2$. We now describe the within-task algorithms of interest and the initialization-based meta-algorithms for which we show guarantees.

5.1. Within-task and meta-learning algorithms

We are interested in the following within-task algorithms.

$\mathrm{GD}_{\text{pop}}(\rho; (A_0, w_0))$ - Population GD:
Runs gradient descent with vanishingly small learning rate (gradient flow) to convergence on $\ell_\rho$. Starting from $(A_0, w_0)$, follow the dynamics below, return $(A_\infty, w_\infty)$.
$$\frac{dA_t}{dt} = -\nabla_A \ell_\rho((A_t, w_t)); \quad \frac{dw_t}{dt} = -\nabla_w \ell_\rho((A_t, w_t))$$
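Since the population loss of Equation 8 is available in closed form for this instance, $\mathrm{GD}_{\text{pop}}$ can be simulated directly. The following sketch (hypothetical names; gradient flow approximated by small discrete steps) is one way to do so and is not the authors' implementation.

```python
import numpy as np

def gd_pop(v, A0, w0, lr=1e-3, max_steps=200_000, tol=1e-12):
    """Approximate GD_pop(rho_v; (A0, w0)) for the two-layer linear net:
    gradient flow on ||A^T w - v||^2 (the population loss of Eq. 8, up to sigma^2),
    simulated with small gradient descent steps."""
    A, w = A0.copy(), w0.copy()
    for _ in range(max_steps):
        r = A.T @ w - v                    # residual A^T w - v
        grad_A = 2.0 * np.outer(w, r)      # gradient of ||A^T w - v||^2 w.r.t. A
        grad_w = 2.0 * A @ r               # gradient of ||A^T w - v||^2 w.r.t. w
        A_next, w_next = A - lr * grad_A, w - lr * grad_w
        if max(np.abs(A_next - A).max(), np.abs(w_next - w).max()) <= tol:
            return A_next, w_next
        A, w = A_next, w_next
    return A, w
```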
We show guarantees for initializations learned by the following meta-learning algorithms.

$\mathrm{Reptile}(\rho_{1:T}, (A_0, w_0))$ - Reptile:
Starting from the initialization $(A_0, w_0)$, on encountering the task $\rho_{i+1}$ the algorithm is sequentially updated as $(A_{i+1}, w_{i+1}) = (1 - \tau)(A_i, w_i) + \tau\, \mathrm{GD}_{\text{pop}}(\ell_{\rho_{i+1}}, (A_i, w_i))$ for some $0 < \tau < 1$. At the end of $T$ tasks, return $A_T$.

On encountering a new task, Reptile slowly interpolates between the current initialization and the solution for the new task obtained by running gradient descent on it starting from the current initialization. As mentioned earlier, this method has enjoyed empirical success (McMahan et al., 2017). The second algorithm of interest is reminiscent of multi-task representation learning.

$\mathrm{RepLearn}(\rho_{1:T}, (A_0, w_{0,1:T}))$ - Representation learning:
Starting from $(A_0, w_{0,1:T})$, run gradient flow on the following objective function: $\mathcal{L}_{\text{rep}}(A, w_{1:T}) = \frac{1}{T}\sum_{i=1}^{T} \ell_{\rho_i}(A, w_i)$; return $A_\infty$ at the end.
$$\frac{dA_t}{dt} = -\nabla_A \mathcal{L}_{\text{rep}}(A_t, w_{t,1:T})$$
$$\frac{dw_{t,i}}{dt} = -\nabla_{w_{t,i}} \mathcal{L}_{\text{rep}}(A_t, w_{t,1:T}), \quad i \in [T]$$

This is a standard objective for multi-task representation learning used in prior work, occasionally equipped with a regularization term for $w_{1:T}$. For our analysis we do not need an explicit regularizer, just like Saxe et al. (2014) and Gidel et al. (2019).
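A minimal simulation of these two meta-learning procedures for the subspace instance might look as follows (hypothetical names; gradient flow replaced by small discrete gradient steps; an illustrative sketch under these assumptions rather than the authors' implementation).

```python
import numpy as np

def approx_gd_pop(v, A, w, lr=1e-3, steps=50_000):
    """Approximate GD_pop: small gradient steps on the population loss ||A^T w - v||^2 of task rho_v."""
    A, w = A.copy(), w.copy()
    for _ in range(steps):
        r = A.T @ w - v
        A, w = A - lr * 2.0 * np.outer(w, r), w - lr * 2.0 * A @ r
    return A, w

def reptile(tasks, A0, w0, tau=0.1):
    """Reptile-like meta-training: interpolate toward each task's GD_pop solution."""
    A, w = A0.copy(), w0.copy()
    for v in tasks:                        # each v is s * w_star for some s in {+1, -1}
        A_task, w_task = approx_gd_pop(v, A, w)
        A = (1.0 - tau) * A + tau * A_task
        w = (1.0 - tau) * w + tau * w_task
    return A                               # A_T: the meta-learned representation

def rep_learn(tasks, A0, W0, lr=1e-3, steps=50_000):
    """RepLearn: joint gradient descent on L_rep(A, w_{1:T}) = (1/T) sum_i ||A^T w_i - v_i||^2."""
    A, W = A0.copy(), W0.copy()            # W[:, i] is the task-specific head w_i
    T = len(tasks)
    for _ in range(steps):
        grad_A = np.zeros_like(A)
        grad_W = np.zeros_like(W)
        for i, v in enumerate(tasks):
            r = A.T @ W[:, i] - v
            grad_A += (2.0 / T) * np.outer(W[:, i], r)
            grad_W[:, i] = (2.0 / T) * A @ r
        A, W = A - lr * grad_A, W - lr * grad_W
    return A
```

In both cases the returned matrix $A$ serves as the meta-learned representation; the paper's trajectory-based analysis tracks how, along the optimization path, this representation comes to emphasize the direction of $w^*$.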