Sequence-Level Knowledge Distillation

Yoon Kim, Alexander M. Rush
School of Engineering and Applied Sciences, Harvard University, Cambridge, MA, USA

Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pages 1317–1327, Austin, Texas, November 1-5, 2016. © 2016 Association for Computational Linguistics

Abstract

Neural machine translation (NMT) offers a novel alternative formulation of translation that is potentially simpler than statistical approaches. However, to reach competitive performance, NMT models need to be exceedingly large. In this paper we consider applying knowledge distillation approaches (Bucila et al., 2006; Hinton et al., 2015) that have proven successful for reducing the size of neural models in other domains to the problem of NMT. We demonstrate that standard knowledge distillation applied to word-level prediction can be effective for NMT, and also introduce two novel sequence-level versions of knowledge distillation that further improve performance and, somewhat surprisingly, seem to eliminate the need for beam search (even when applied on the original teacher model). Our best student model runs 10 times faster than its state-of-the-art teacher with little loss in performance. It is also significantly better than a baseline model trained without knowledge distillation: by 4.2/1.7 BLEU with greedy decoding/beam search. Applying weight pruning on top of knowledge distillation results in a student model that has 13× fewer parameters than the original teacher model, with a decrease of 0.4 BLEU.

1 Introduction

Neural machine translation (NMT) (Kalchbrenner and Blunsom, 2013; Cho et al., 2014; Sutskever et al., 2014; Bahdanau et al., 2015) is a deep learning-based method for translation that has recently shown promising results as an alternative to statistical approaches. NMT systems directly model the probability of the next word in the target sentence simply by conditioning a recurrent neural network on the source sentence and previously generated target words.

While both simple and surprisingly accurate, NMT systems typically need to have very high capacity in order to perform well: Sutskever et al. (2014) used a 4-layer LSTM with 1000 hidden units per layer (herein 4 × 1000) and Zhou et al. (2016) obtained state-of-the-art results on English → French with a 16-layer LSTM with 512 units per layer. The sheer size of the models requires cutting-edge hardware for training and makes using the models on standard setups very challenging.

This issue of excessively large networks has been observed in several other domains, with much focus on fully-connected and convolutional networks for multi-class classification. Researchers have particularly noted that large networks seem to be necessary for training, but learn redundant representations in the process (Denil et al., 2013). Therefore compressing deep models into smaller networks has been an active area of research. As deep learning systems obtain better results on NLP tasks, compression also becomes an important practical issue with applications such as running deep learning models for speech and translation locally on cell phones.

Existing compression methods generally fall into two categories: (1) pruning and (2) knowledge distillation. Pruning methods (LeCun et al., 1990; He et al., 2014; Han et al., 2016) zero out weights or entire neurons based on an importance criterion: LeCun et al. (1990) use (a diagonal approximation to) the Hessian to identify weights whose removal minimally impacts the objective function, while Han et al. (2016) remove weights based on thresholding their absolute values. Knowledge distillation approaches (Bucila et al., 2006; Ba and Caruana, 2014; Hinton et al., 2015) learn a smaller student network to mimic the original teacher network by minimizing the loss (typically $L_2$ or cross-entropy) between the student and teacher output.

In this work, we investigate knowledge distillation in the context of neural machine translation. We note that NMT differs from previous work which has mainly explored non-recurrent models in the multi-class prediction setting. For NMT, while the model is trained on multi-class prediction at the word-level, it is tasked with predicting complete sequence outputs conditioned on previous decisions. With this difference in mind, we experiment with standard knowledge distillation for NMT and also propose two new versions of the approach that attempt to approximately match the sequence-level (as opposed to word-level) distribution of the teacher network. This sequence-level approximation leads to a simple training procedure wherein the student network is trained on a newly generated dataset that is the result of running beam search with the teacher network.

We run experiments to compress a large state-of-the-art 4 × 1000 LSTM model, and find that with sequence-level knowledge distillation we are able to learn a 2 × 500 LSTM that roughly matches the performance of the full system. We see similar results compressing a 2 × 500 model down to 2 × 100 on a smaller data set. Furthermore, we observe that our proposed approach has other benefits, such as not requiring any beam search at test-time. As a result we are able to perform greedy decoding on the 2 × 500 model 10 times faster than beam search on the 4 × 1000 model with comparable performance. Our student models can even be run efficiently on a standard smartphone.[1] Finally, we apply weight pruning on top of the student network to obtain a model that has 13× fewer parameters than the original teacher model. We have released all the code for the models described in this paper.[2]

[1] https://github.com/harvardnlp/nmt-android
[2] https://github.com/harvardnlp/seq2seq-attn

2 Background

2.1 Sequence-to-Sequence with Attention

Let $s = [s_1, \ldots, s_I]$ and $t = [t_1, \ldots, t_J]$ be (random variable sequences representing) the source/target sentence, with $I$ and $J$ respectively being the source/target lengths. Machine translation involves finding the most probable target sentence given the source:

$$\operatorname*{argmax}_{t \in \mathcal{T}} \; p(t \mid s)$$

where $\mathcal{T}$ is the set of all possible sequences. NMT models parameterize $p(t \mid s)$ with an encoder neural network which reads the source sentence and a decoder neural network which produces a distribution over the target sentence (one word at a time) given the source. We employ the attentional architecture from Luong et al. (2015), which achieved state-of-the-art results on English → German translation.[3]

[3] Specifically, we use the global-general attention model with the input-feeding approach. We refer the reader to the original paper for further details.

2.2 Knowledge Distillation

Knowledge distillation describes a class of methods for training a smaller student network to perform better by learning from a larger teacher network (in addition to learning from the training data set). We generally assume that the teacher has previously been trained, and that we are estimating parameters for the student. Knowledge distillation suggests training by matching the student's predictions to the teacher's predictions. For classification this usually means matching the probabilities either via $L_2$ on the log scale (Ba and Caruana, 2014) or by cross-entropy (Li et al., 2014; Hinton et al., 2015).

Concretely, assume we are learning a multi-class classifier over a data set of examples of the form $(x, y)$ with possible classes $\mathcal{V}$. The usual training criterion is to minimize NLL for each example from the training data,

$$\mathcal{L}_{\text{NLL}}(\theta) = -\sum_{k=1}^{|\mathcal{V}|} \mathbb{1}\{y = k\} \log p(y = k \mid x; \theta)$$

where $\mathbb{1}\{\cdot\}$ is the indicator function and $p$ the distribution from our model (parameterized by $\theta$).
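As a small illustration of the NLL criterion for a single example, the sketch below evaluates the indicator sum directly. It is a minimal sketch, not the authors' released code: the function name `nll_loss` and the three-class probability vector are hypothetical and chosen only for illustration.

```python
import math

def nll_loss(probs, gold):
    """NLL for one example: -sum_k 1{y = k} * log p(y = k | x; theta).

    probs: model distribution over the classes (a list summing to 1).
    gold:  index of the observed class y.
    """
    # The indicator zeroes out every term except the gold class, so
    # non-gold classes are skipped (this also avoids log(0) for them).
    return -sum(math.log(p) for k, p in enumerate(probs) if k == gold)

# Hypothetical model distribution over |V| = 3 classes for one input x.
probs = [0.7, 0.2, 0.1]
loss = nll_loss(probs, gold=0)
```

Only the probability assigned to the observed class contributes, which is why the loss here equals the negative log of 0.7.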
This objective can be seen as minimizing the cross-entropy between the degenerate data distribution (which has all of its probability mass on one class) and the model distribution $p(y \mid x; \theta)$.

In knowledge distillation, we assume access to a learned teacher distribution $q(y \mid x; \theta_T)$, possibly trained over the same data set. Instead of minimizing cross-entropy with the observed data, we minimize the cross-entropy with the teacher's probability distribution,

$$\mathcal{L}_{\text{KD}}(\theta; \theta_T) = -\sum_{k=1}^{|\mathcal{V}|} q(y = k \mid x; \theta_T) \times \log p(y = k \mid x; \theta)$$

where $\theta_T$ parameterizes the teacher distribution and remains fixed. Note the cross-entropy setup is identical, but the target distribution is no longer a sparse distribution.

Since this new objective has no direct term for the training data, it is common practice to interpolate between the two losses,

$$\mathcal{L}(\theta; \theta_T) = (1 - \alpha)\,\mathcal{L}_{\text{NLL}}(\theta) + \alpha\,\mathcal{L}_{\text{KD}}(\theta; \theta_T)$$

where $\alpha$ is a mixture parameter combining the one-hot distribution and the teacher distribution.

Figure 1: Overview of the different knowledge distillation approaches. In word-level knowledge distillation (left) cross-entropy is minimized between the student/teacher distributions (yellow) for each word in the actual target sequence (ECD), as well as between the student distribution and the degenerate data distribution, which has all of its probability mass on one word (black). In sequence-level knowledge distillation (center) the student network is trained on the output from beam search of the teacher network that had the highest score (ACF). In sequence-level interpolation (right) the student is trained on the output from beam search of the teacher network that had the highest sim with the target sequence (ECE).

3 Knowledge Distillation for NMT

The large sizes of neural machine translation systems make them an ideal candidate for knowledge distillation approaches. In this section we explore three different ways this technique can be applied to NMT.

3.1 Word-Level Knowledge Distillation

NMT systems are trained directly to minimize word NLL at each position.
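Before these losses are applied per word, it may help to see the classification-level quantities from Section 2.2 computed for one example. The following is a minimal sketch under stated assumptions: the helper names `cross_entropy` and `distillation_loss`, the student and teacher probability vectors, and the value of alpha are all made up for illustration and are not taken from the paper or its released code.

```python
import math

def cross_entropy(target, probs):
    """-sum_k target[k] * log probs[k], skipping classes with zero target mass."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0.0)

def distillation_loss(student, teacher, gold, alpha):
    """(1 - alpha) * L_NLL + alpha * L_KD for a single example."""
    # Degenerate data distribution: all mass on the observed class.
    one_hot = [1.0 if k == gold else 0.0 for k in range(len(student))]
    nll = cross_entropy(one_hot, student)
    # Teacher distribution q is treated as a fixed target.
    kd = cross_entropy(teacher, student)
    return (1.0 - alpha) * nll + alpha * kd

student = [0.6, 0.3, 0.1]   # hypothetical p(y | x; theta)
teacher = [0.5, 0.4, 0.1]   # hypothetical q(y | x; theta_T)
loss = distillation_loss(student, teacher, gold=0, alpha=0.5)
```

Because cross-entropy is linear in its target, the interpolated loss equals the cross-entropy against the mixed target $(1 - \alpha)\,\text{one-hot} + \alpha\,q$, which matches the reading of $\alpha$ as mixing the one-hot and teacher distributions. Setting `alpha=0` recovers plain NLL and `alpha=1` recovers the pure distillation loss.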
