Efficient and Robust Multi-Task Learning in the Brain with Modular Task Primitives
arXiv:2105.14108v1 [cs.AI] 28 May 2021

Christian D. Márton
Friedman Brain Institute, Icahn School of Medicine at Mount Sinai, New York, NY 10029
[email protected]

Guillaume Lajoie
Dep. of Mathematics and Statistics, Université de Montréal; MILA - Québec AI Institute
[email protected]

Kanaka Rajan
Friedman Brain Institute, Icahn School of Medicine at Mount Sinai, New York, NY 10029
[email protected]

Abstract

In a real-world setting, biological agents do not have infinite resources to learn new things. It is thus useful to recycle previously acquired knowledge in a way that allows for faster, less resource-intensive acquisition of multiple new skills. Neural networks in the brain are likely not entirely re-trained with new tasks, but how they leverage existing computations to learn new tasks is not well understood. In this work, we study this question in artificial neural networks trained on commonly used neuroscience paradigms. Building on recent work from the multi-task learning literature, we propose two ingredients: (1) network modularity, and (2) learning task primitives. Together, these ingredients form inductive biases we call structural and functional, respectively. Using a corpus of nine different tasks, we show that a modular network endowed with task primitives allows for learning multiple tasks well while keeping parameter counts, and updates, low. We also show that the skills acquired with our approach are more robust to a broad range of perturbations than those acquired with other multi-task learning strategies. This work offers a new perspective on achieving efficient multi-task learning in the brain, and makes predictions for novel neuroscience experiments in which targeted perturbations are employed to explore solution spaces.

1 Introduction

Task knowledge can be stored in the collective dynamics of large populations of neurons in the brain [37, 48, 64, 46, 27, 68, 10, 45, 58, 34, 57, 2, 8, 56]. This body of work has offered insight into the computational mechanisms underlying the execution of multiple different tasks across various domains. In these studies, however, animals and models are over-trained on specific tasks, at the expense of long training durations and high energy expenditure. Models are optimized for one specific task and, unlike networks of neurons in the brain, are not meant to be re-used across tasks.

In addition to being resource-efficient, it is useful for acquired task representations to be resistant to failures and perturbations. Biological organisms are remarkably resistant to lesions affecting large extents of the neural circuit [3, 12], and to neural cell death with aging [18, 63]. Current approaches do not take this into account; often, representations acquired naively are highly susceptible to perturbations [60].

In the multi-task setting, where multiple tasks are learnt simultaneously [64] or sequentially [15, 69, 28], the number of weight updates grows quickly with the number of tasks, as the entire model is essentially re-trained with each new task. Training time and energy cost increase rapidly and can become prohibitive. To overcome this, it is conceivable that the brain has adopted particular inductive biases, shaped over an evolutionary timescale, to speed up learning [48]. Inductive biases can help guide and accelerate future learning by providing useful prior knowledge.
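To give a sense of how quickly training costs grow with the number of tasks, the following back-of-the-envelope count compares re-training a full recurrent network per task with training only per-task output weights on top of a fixed recurrent core. This is a minimal sketch; the network sizes are assumed for illustration and are not figures from the paper.

```python
# Back-of-the-envelope count of trained parameters in a multi-task setting.
# Sizes below are assumed for illustration only, not taken from the paper.

def full_retraining_params(n_rec, n_in, n_out, n_tasks):
    # Every task re-trains recurrent, input, and output weights.
    per_task = n_rec * n_rec + n_in * n_rec + n_rec * n_out
    return n_tasks * per_task

def readout_only_params(n_rec, n_in, n_out, n_tasks):
    # Recurrent and input weights are trained once (or fixed);
    # each new task only adds a readout of n_rec * n_out weights.
    core = n_rec * n_rec + n_in * n_rec
    return core + n_tasks * n_rec * n_out

n_rec, n_in, n_out, n_tasks = 300, 4, 2, 9
print(full_retraining_params(n_rec, n_in, n_out, n_tasks))  # 826200
print(readout_only_params(n_rec, n_in, n_out, n_tasks))     # 96600
```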
Prior knowledge of this kind becomes particularly useful when resources (energy, time, neurons) to acquire new skills are limited, and the space of tasks that need to be acquired is not infinite. In fact, biological agents often inhabit particular ecological niches [9] which place strong constraints on the kinds of tasks that are useful and need to be acquired over a lifetime. Under these circumstances, useful priors or inductive biases can help decrease the number of necessary weight updates.

One way in which such biases may be expressed in the brain is in terms of neural architecture, or structure [48]. Neural networks in the brain may come with a pre-configured wiring diagram, encoded in the genome [67]. Such structural pre-configurations can place useful constraints on the learning problem. One prominent example is modularity. The brain exhibits specialization over various scales [55, 62, 21, 36, 35], with visual information primarily represented in the visual cortices, auditory information in the auditory cortices, and decision-making tasks encoded in circuits in the prefrontal cortex [37, 10, 34]. Further subdivisions within those specialized modules may arise in the form of particular cell clusters, or groupings in a graph-theoretic sense, processing particular kinds of information (such as cells mapping particular visual features or aspects of a task). From a multi-task perspective, modularity can offer an efficient solution to the problem of task interference and catastrophic forgetting in continual learning [23, 40, 49], while increasing robustness to perturbations.

Another way in which inductive biases may be expressed in the brain, beyond structure, is in terms of function: the types of neural dynamics instantiated in particular groups of neurons, or neural ensembles. Rather than learning separate models for separate tasks, as many of the studies cited above do, or re-training the same circuit many times on different tasks [64, 15], core computational principles could be learned once and re-used across tasks. There is evidence that the brain implements stereotyped dynamics which can be re-used across tasks [13, 51]. Ensembles of neurons may be divided into distinct groups based on the nature of the computations they perform, forming functional modules [42, 55, 5, 44, 35]. In the motor domain, for instance, neural activity has been found to exhibit shared structure across tasks [19]. Different neural ensembles have been found to implement particular aspects of a task [29, 64, 10], and theoretical work suggests that the number of distinct neural sub-populations influences the set of dynamics a neural network can perform [4]. Thus, a network that implements a form of task primitives, or basis computations, could offer dynamics useful for training multiple tasks.

Putting these strands together, in this work we investigate the advantage of combining structural and functional inductive biases in a recurrent neural network (RNN) model of modular task primitives. Structural specialization in the form of modules can help avoid task interference and increase robustness to the failure of particular neurons, while core functional specialization in the form of task primitives can substantially lower the cost of training. We show how our model can be used to facilitate resource-efficient learning in a multi-task setting. We also compare our approach to other training paradigms and demonstrate higher robustness to perturbations.
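As a concrete illustration of the structural bias, the sketch below builds a recurrent weight matrix in which three modules are kept fully disconnected from one another by a block-diagonal mask, so that only within-module connections exist. It is illustrative only; the module sizes and the gain g are assumed and are not values taken from the experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

module_sizes = [100, 100, 100]   # assumed sizes for modules M1-M3 (illustration only)
n = sum(module_sizes)

# Block-diagonal mask: 1 within a module, 0 between modules,
# so the three modules are fully disconnected from one another.
mask = np.zeros((n, n))
start = 0
for size in module_sizes:
    mask[start:start + size, start:start + size] = 1.0
    start += size

# Random recurrent weights, scaled for roughly stable dynamics, then masked.
g = 1.2                          # gain; an assumed hyperparameter
W_rec = mask * (g * rng.standard_normal((n, n)) / np.sqrt(module_sizes[0]))

# No connections between modules, e.g. from M2/M3 into M1:
assert np.allclose(W_rec[:100, 100:], 0.0)
```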
Altogether, our work shows that combining structural and functional inductive biases in a recurrent network allows an agent to learn many tasks well, at lower cost and with higher robustness than other approaches. This draws a path towards cheaper multi-task solutions, and offers new hypotheses for how multi-task learning may be achieved in the brain.

2 Background

The brain is highly recurrently connected [30], and recurrent neural networks have been successfully used to model dynamics in the brain [37, 48, 64, 46, 27, 68, 10, 45, 58, 34, 57, 2, 8, 56]. Reservoir computing approaches, such as those proposed in the context of echo-state networks (ESNs) and liquid state machines (LSMs) [41, 26, 32], avoid learning the recurrent weights of a network (training only output weights) and thereby achieve reduced training time and lower computational cost compared to recurrent neural networks (RNNs) whose weights are all trained end-to-end. This can be advantageous: when full state dynamics are available for training, reservoir networks exhibit higher predictive performance and lower generalisation error at significantly lower training cost than recurrent networks fully trained with gradient descent using backpropagation through time (BPTT); in the face of partial observability, however, RNNs trained with BPTT perform better [61]. Naive implementations of reservoir networks, moreover, show high susceptibility to local perturbations, even when implemented in a modular way [60].

Fig 1. Schematic of the fully disconnected recurrent network with modular task primitives. Left: training inputs for modules M1-M3 (instruction, fixation, autoencoding, and memory signals). Middle: recurrent weight matrix of the modular network, with separate partitions for modules M1-M3 along the diagonal. Right: different tasks (T1-T9) trained into different output weights, leveraging the same modular network.

Meanwhile, previous work in continual learning has largely focused on feedforward architectures [23, 66, 1, 40, 20, 65, 53, 52, 68, 28, 69, 31, 14, 50]. The results obtained in feedforward systems, however, are not readily extensible to continual learning in dynamical systems, where recurrent interactions may lead to unexpected behavior in the evolution of neural responses over time [54, 15]. Modular approaches, too, have previously been explored largely in feedforward systems [59, 16, 24]. Recently, a method for continual learning in dynamical systems has been proposed [15]; its authors show that their method, dynamical non-interference, outperforms other gradient-based approaches [69, 28] in key applications. All of these approaches, however, require the entire network to be re-trained with each new task and thus come with a high training cost. A network with functionally specialized modules could provide dynamics that are useful across tasks and thus help mitigate this problem. In its use of modules, our approach is similar to memory-based approaches to continual learning, which have also been applied in recurrently connected architectures [22, 11].
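In the spirit of the reservoir approaches above and the scheme sketched in Fig 1, the following minimal example keeps the recurrent and input weights of a network fixed and fits a separate linear readout for each task from the network's states. The ridge-regression readout, the random recurrent core, and the task names are assumptions made for illustration; this is not the authors' exact training procedure.

```python
import numpy as np

rng = np.random.default_rng(1)

def run_rnn(W_rec, W_in, inputs, dt=0.1):
    """Simulate simple rate dynamics and return the firing rates over time."""
    x = np.zeros(W_rec.shape[0])
    rates = np.zeros((len(inputs), len(x)))
    for t, u in enumerate(inputs):
        x = x + dt * (-x + W_rec @ np.tanh(x) + W_in @ u)
        rates[t] = np.tanh(x)
    return rates

def fit_readout(rates, targets, alpha=1e-2):
    """Ridge-regression readout: one set of output weights per task."""
    A = rates.T @ rates + alpha * np.eye(rates.shape[1])
    return np.linalg.solve(A, rates.T @ targets)

# A fixed recurrent core (here random; in the paper it would hold the trained
# task-primitive modules) shared across tasks, with per-task output weights.
n, n_in = 300, 2
W_rec = rng.standard_normal((n, n)) / np.sqrt(n)
W_in = rng.standard_normal((n, n_in)) / np.sqrt(n_in)

readouts = {}
for task in ("task_A", "task_B"):               # hypothetical task names
    inputs = rng.standard_normal((200, n_in))   # placeholder task inputs
    targets = rng.standard_normal((200, 1))     # placeholder task targets
    rates = run_rnn(W_rec, W_in, inputs)
    readouts[task] = fit_readout(rates, targets)  # only these weights are task-specific
```

The point of the sketch is the cost structure: adding a task touches only an n-by-n_out readout, never the shared recurrent weights.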