Learning Where to Sample in Structured Prediction
Tianlin Shi (Tsinghua University), Jacob Steinhardt (Stanford University), Percy Liang (Stanford University)

Appearing in Proceedings of the 18th International Conference on Artificial Intelligence and Statistics (AISTATS) 2015, San Diego, CA, USA. JMLR: W&CP volume 38. Copyright 2015 by the authors.

Abstract

In structured prediction, most inference algorithms allocate a homogeneous amount of computation to all parts of the output, which can be wasteful when different parts vary widely in terms of difficulty. In this paper, we propose a heterogeneous approach that dynamically allocates computation to the different parts. Given a pre-trained model, we tune its inference algorithm (a sampler) to increase test-time throughput. The inference algorithm is parametrized by a meta-model and trained via reinforcement learning, where actions correspond to sampling candidate parts of the output, and rewards are log-likelihood improvements. The meta-model is based on a set of domain-general meta-features capturing the progress of the sampler. We test our approach on five datasets and show that it attains the same accuracy as Gibbs sampling but is 2 to 5 times faster.

1 Introduction

For many structured prediction problems, the output contains many interdependent variables, resulting in exponentially large output spaces. These properties make exact inference intractable for models with high treewidth [Koller et al., 2007], and thus we must rely on approximations such as variational inference and Markov Chain Monte Carlo (MCMC). A key characteristic of many such approximate inference algorithms is that they iteratively modify only a local part of the output structure at a small computational cost. For example, Gibbs sampling [Brooks et al., 2011] updates the output by sampling one variable conditioned on the rest. Often, a large number of such local moves is required.

One source of inefficiency in Gibbs sampling is that it dedicates a homogeneous amount of inference to each part of the output. In practice, however, the difficulty and inferential demands of the parts are heterogeneous. For example, in rendering computer graphics, paths of light passing through highly reflective or glossy regions deserve more sampling [Veach, 1997]; in named-entity recognition [McCallum and Li, 2003], most tokens clearly do not contain entities and therefore should be allocated less computation. Attempts have been made to capture this heterogeneity. For example, Elidan et al. [2006] schedule updates in asynchronous belief propagation based on the information residuals of the messages. Chechetka and Guestrin [2010] focus the computation of belief propagation on the specific variables being queried. Other work has built cascades of coarse-to-fine models, in which simple models filter out unnecessary parts of the output and reduce the computational burden on complex models [Viola and Jones, 2001, Weiss and Taskar, 2010].

We propose a framework that constructs heterogeneous sampling algorithms using reinforcement learning (RL). We start with a collection of transition kernels, each of which proposes a modification to part of the output (in this paper, we use transition kernels derived from Gibbs sampling). At each step, our procedure chooses which transition kernel to apply based on cues from the input and the history of proposed outputs. By optimizing this procedure, we fine-tune inference to exploit the specific heterogeneity in the task of interest, thus saving overall computation at test time.

The main challenge is to find signals that consistently provide useful cues (meta-features) for directing the focus of inference across a wide variety of tasks. Moreover, it is important that these signals be cheap to compute relative to the cost of generating a sample, as otherwise the meta-features themselves become the bottleneck of the algorithm. In this paper, we provide general principles for constructing such meta-features, based on reasoning about the staleness and discord of variables in the output. We cash these ideas out as a collection of five concrete meta-features, which empirically yield good predictions across tasks as diverse as part-of-speech (POS) tagging, named-entity recognition (NER), handwriting recognition, color inpainting, and scene decomposition.

In summary, our contributions are:

• The conceptual idea of learning to sample: we present a learning framework based on RL, and discuss meta-features that leverage heterogeneity.
• The practical value of the framework: given a pre-trained model, we can effectively optimize the test-time throughput of its Gibbs sampler.

2 Heterogeneous Sampling

Before we formalize our framework for heterogeneous sampling, let us consider a motivating example. Suppose our task is part-of-speech (POS) tagging, where the input $x \in \mathcal{X}$ is a sentence and the output $y \in \mathcal{Y}$ is the sequence of POS tags for its words. An example is shown in Table 1.

          x:  I    think  now  is   the  right  time
  pass 1  y:  PRP  VBP    RB   VBZ  DT   NN     NN
  pass 2  y:  PRP  VBP    RB   VBZ  DT   JJ     NN

Table 1: A POS tagging example. Outputs are recorded after each sweep of Gibbs sampling. Only the ambiguous token "right" (NN: noun, JJ: adjective) needs more inference at the second sweep.

Suppose that the full model is a chain-structured conditional random field (CRF) with unigram potentials on each tag and bigram potentials between adjacent tags. Exact inference algorithms exist for this model, but for illustrative purposes we use cyclic Gibbs sampling, which samples from the conditional distribution of each tag in cyclic order from left to right.

The example in Table 1 shows that at least two sweeps of cyclic Gibbs sampling are required, because it is hard to know whether "right" is an adjective or a noun until the tag for the following word, "time", has been sampled. However, the second pass wastes computation by resampling tags that were already mostly determined in the first pass. This inspires the following inference strategy (sketched in code below):

  pass 1: sample the tags for each word.
  pass 2: sample the tag for "right".

In general, it is desirable to have the inference algorithm itself figure out which locations to sample and, at a higher level, which test instances to sample.
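To make the example concrete, below is a minimal Python sketch of cyclic Gibbs sampling for such a chain CRF, plus the targeted second pass described above. The potential tables (`unigram`, `bigram`) and all function names are assumptions of this sketch, not constructs from the paper.

```python
import numpy as np

def gibbs_step(y, j, unigram, bigram, rng):
    """Resample tag y[j] from its conditional distribution in a chain CRF
    with unigram log-potentials unigram[j, t] and bigram log-potentials
    bigram[s, t] between adjacent tags."""
    m, K = unigram.shape
    logits = unigram[j].copy()
    if j > 0:
        logits += bigram[y[j - 1], :]   # coupling to the left neighbor
    if j < m - 1:
        logits += bigram[:, y[j + 1]]   # coupling to the right neighbor
    probs = np.exp(logits - logits.max())
    y[j] = rng.choice(K, p=probs / probs.sum())

def cyclic_sweep(y, unigram, bigram, rng):
    """One full pass of cyclic Gibbs: every tag, left to right."""
    for j in range(len(y)):
        gibbs_step(y, j, unigram, bigram, rng)

def targeted_pass(y, positions, unigram, bigram, rng):
    """A heterogeneous pass: resample only the positions still in doubt."""
    for j in positions:
        gibbs_step(y, j, unigram, bigram, rng)

# Pass 1 sweeps all 7 tags; pass 2 touches only "right" (position 5):
#   rng = np.random.default_rng(0)
#   cyclic_sweep(y, unigram, bigram, rng)
#   targeted_pass(y, [5], unigram, bigram, rng)
```

A full cyclic sweep always costs one transition per position, whereas the targeted pass spends transitions only where the first sweep left ambiguity, which is exactly the behavior a learned sampler should discover.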
2.1 Framework

We now formalize the intuition from the previous example. Assume our pre-trained model specifies a distribution $p(y \mid x)$. On a set of test inputs $X = (x^{(1)}, \ldots, x^{(n)})$, we would like to infer the outputs $Y = (y^{(1)}, \ldots, y^{(n)})$ using some inference algorithm $M$. To simplify notation, we will focus on a single instance $(x, y)$, though our final algorithm considers all test instances jointly. Notice that we can reduce from the multiple-instance case to the single-instance case by just concatenating all the instances into a single instance.

We further assume that a single output $y = (y_1, \ldots, y_m)$ is represented by $m$ variables. For instance, in POS tagging, $y_j$ ($j = 1, \ldots, m$) is the part of speech of the $j$-th word in the sentence $x$. We are given a collection of transition kernels which target the distribution $p(y \mid x)$. For Gibbs sampling, we have the kernels $\{A_j(y' \mid y) : j = 1, \ldots, m\}$, where the transition $A_j$ samples $y'_j$ conditioned on all other variables $y_{-j}$, and leaves $y'_{-j}$ equal to $y_{-j}$.

Algorithm 1 describes the form of samplers we consider.

Algorithm 1 Template for a heterogeneous sampler
1: Initialize $y[0] \sim P_0(y)$ for some initializing distribution $P_0(\cdot)$.
2: for $t = 1$ to $T_M(x)$ do
3:   select a transition kernel $A_j$ for some $1 \le j \le m$
4:   sample $y[t] \sim A_j(\cdot \mid y[t-1])$
5: end for
6: output $y = y[T]$

A sampler generates a sequence of outputs $y[1], y[2], \ldots$ by iteratively selecting a variable index $j[t]$ and sampling $y[t+1] \sim A_{j[t]}(\cdot \mid y[t])$. Rather than applying the transition kernels in a fixed order, our samplers select the transition $A_{j[t]}$ to apply based on the input $x$ together with the sampling history. The total number of Markov transitions $T_M(x)$ made by $M$ characterizes its computational cost on input $x$.
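A direct Python rendering of this template may help fix ideas. The `select` callable is the slot that the learned meta-model fills; its signature and all other names here are assumptions of this sketch.

```python
def heterogeneous_sampler(x, kernels, select, budget, init, rng):
    """Template of Algorithm 1: repeatedly pick one transition kernel
    A_j and apply it, instead of sweeping kernels in a fixed order.

    kernels: list of m callables; kernels[j](y, rng) returns a sample
             from A_j(. | y), i.e. y with only y_j resampled.
    select:  policy mapping (x, y, history) -> kernel index j; this is
             where the learned meta-model plugs in.
    budget:  total number of Markov transitions T_M(x).
    init:    callable drawing the initial output y[0] ~ P_0.
    """
    y = init(x, rng)                 # line 1: y[0] ~ P_0(y)
    history = []
    for t in range(budget):          # line 2: t = 1 .. T_M(x)
        j = select(x, y, history)    # line 3: choose kernel A_j
        y = kernels[j](y, rng)       # line 4: y[t] ~ A_j(. | y[t-1])
        history.append(j)
    return y                         # line 6: output y = y[T]
```

Setting `select = lambda x, y, h: len(h) % len(kernels)` recovers cyclic Gibbs sampling; the point of the framework is to replace this fixed schedule with a learned policy.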
How do we choose which transition kernel to apply? A natural objective is to maximize the expected log-likelihood, under the model $p(y \mid x)$, of the sampler output $M(x)$:

$$\max_{M} \; \mathbb{E}_{q_M(y \mid x)}\big[\log p(y \mid x)\big] \quad \text{s.t. } T_M(x) \le T^{*}, \tag{1}$$

where $q_M(y \mid x)$ is the probability that $M$ outputs $y$ and $T^{*}$ is the computation budget. Equation (1) says that we want $q_M$ to place as much probability mass as possible on values of $y$ that have high probability under $p$, subject to the constraint $T^{*}$ on the amount of computation. Note that if $T^{*} = \infty$, the optimal solution would be the posterior mode.

Solving this optimization problem at test time is infeasible, so we will instead optimize $M$ on a training set, and then deploy the resulting $M$ at test time.

3 Reinforcement Learning of Heterogeneous Samplers

We would like to optimize the objective in (1), but searching over all samplers $M$ and evaluating the expectation [...]

• [...] we use $Q(s_t, A_j)$ to predict the cumulative reward over a shorter time horizon $H \ll T^{*}$.

• The reward over time $H$ also depends on the context of $y_j$. By subtracting the contextual part of the reward, we can hope to isolate the contribution of action $A_j$. Thus, we use $Q(s_t, A_j)$ to model the difference in reward from taking the transition $A_j$, relative to the baseline of making no transition.

Formally, we learn $Q$ using sample backup [Sutton and Barto, 1998, Chapter 9.5].
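To illustrate the two approximations above, here is a minimal sketch of the regression problem they imply, assuming a linear parameterization $Q(s_t, A_j) = \theta^{\top} \phi(s_t, A_j)$ over meta-features; the reward arrays, the no-transition rollout, and the update rule are illustrative assumptions, not the paper's exact sample-backup estimator.

```python
import numpy as np

def short_horizon_target(rewards, baseline_rewards, t, H):
    """Regression target for Q(s_t, A_j): cumulative log-likelihood
    improvement over the window [t, t+H), minus the reward a matched
    rollout that makes no transition at step t accrues over the same
    window (the 'contextual part' of the reward)."""
    return sum(rewards[t:t + H]) - sum(baseline_rewards[t:t + H])

def q_update(theta, phi, target, lr=0.1):
    """One sampled (not full) backup: a stochastic-gradient step
    fitting the linear model Q = theta . phi toward the target."""
    theta += lr * (target - theta @ phi) * phi
    return theta

# Example with a 5-dimensional feature vector, matching the paper's five
# meta-features (numeric values here are placeholders):
theta = np.zeros(5)
phi = np.array([0.8, 0.0, 1.0, 0.2, 0.5])
theta = q_update(theta, phi, target=0.3)
```

Subtracting the baseline rollout keeps the target focused on what action $A_j$ itself contributed, rather than on reward the sampler would have collected anyway.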