Learning Where to Sample in Structured Prediction

Tianlin Shi, Tsinghua University ([email protected]) · Jacob Steinhardt, Stanford University ([email protected]) · Percy Liang, Stanford University ([email protected])

Abstract

In structured prediction, most inference algorithms allocate a homogeneous amount of computation to all parts of the output, which can be wasteful when different parts vary widely in terms of difficulty. In this paper, we propose a heterogeneous approach that dynamically allocates computation to the different parts. Given a pre-trained model, we tune its inference algorithm (a sampler) to increase test-time throughput. The inference algorithm is parametrized by a meta-model and trained via reinforcement learning, where actions correspond to sampling candidate parts of the output, and rewards are log-likelihood improvements. The meta-model is based on a set of domain-general meta-features capturing the progress of the sampler. We test our approach on five datasets and show that it attains the same accuracy as Gibbs sampling but is 2 to 5 times faster.

1 Introduction

For many structured prediction problems, the output contains many interdependent variables, resulting in exponentially large output spaces. These properties make exact inference intractable for models with high treewidth [Koller et al., 2007], and thus we must rely on approximations such as variational inference and Markov chain Monte Carlo (MCMC). A key characteristic of many such approximate inference algorithms is that they iteratively modify only a local part of the output structure at a small computational cost. For example, Gibbs sampling [Brooks et al., 2011] updates the output by sampling one variable conditioned on the rest. Often, a large number of local moves is required.

One source of inefficiency in Gibbs sampling is that it dedicates a homogeneous amount of inference to each part of the output. However, in practice, the difficulty and inferential demands of each part are heterogeneous. For example, in rendering computer graphics, paths of light passing through highly reflective or glossy regions deserve more sampling [Veach, 1997]; in named-entity recognition [McCallum and Li, 2003], most tokens clearly do not contain entities and therefore should be allocated less computation. Attempts have been made to capture the nature of heterogeneity in such settings. For example, Elidan et al. [2006] schedule updates in asynchronous belief propagation based on the information residuals of the messages. Chechetka and Guestrin [2010] focus the computation of belief propagation on the specific variables being queried. Other work has focused on building cascades of coarse-to-fine models, where simple models filter out unnecessary parts of the output and reduce the computational burden for complex models [Viola and Jones, 2001, Weiss and Taskar, 2010].

We propose a framework that constructs heterogeneous sampling algorithms using reinforcement learning (RL). We start with a collection of transition kernels, each of which proposes a modification to part of the output (in this paper, we use transition kernels derived from Gibbs sampling). At each step, our procedure chooses which transition kernel to apply based on cues from the input and the history of proposed outputs. By optimizing this procedure, we fine-tune inference to exploit the specific heterogeneity in the task of interest, thus saving overall computation at test time.

The main challenge is to find signals that consistently provide useful cues (meta-features) for directing the focus of inference across a wide variety of tasks.
Moreover, it is important that these signals are cheap to compute relative to the cost of generating a sample, as otherwise the meta-features themselves become the bottleneck of the algorithm. In this paper, we provide general principles for constructing such meta-features, based on reasoning about staleness and discord of variables in the output. We cash these ideas out as a collection of five concrete meta-features, which empirically yield good predictions across tasks as diverse as part-of-speech (POS) tagging, named-entity recognition (NER), handwriting recognition, color inpainting, and scene decomposition.

In summary, our contributions are:

• The conceptual idea of learning to sample: we present a learning framework based on RL, and discuss meta-features that leverage heterogeneity.
• The practical value of the framework: given a pre-trained model, we can effectively optimize the test-time throughput of its Gibbs sampler.

2 Heterogeneous Sampling

Before we formalize our framework for heterogeneous sampling, let's consider a motivating example.

x:         I    think  now  is   the  right  time
pass 1 y:  PRP  VBP    RB   VBZ  DT   NN     NN
pass 2 y:  PRP  VBP    RB   VBZ  DT   JJ     NN

Table 1: A POS tagging example. Outputs are recorded after each sweep of Gibbs sampling. Only the ambiguous token "right" (NN: noun, JJ: adjective) needs more inference at the second sweep.

Suppose our task is part-of-speech (POS) tagging, where the input x ∈ X is a sentence and the output y ∈ Y is the sequence of POS tags for the words. An example is shown in Table 1. Suppose that the full model is a chain-structured conditional random field (CRF) with unigram potentials on each tag and bigram potentials between adjacent tags.
Exact infer- than applying the transition kernels in a fixed order, ence algorithms exist for this model, but for illustrative our samplers select the transition Ajrts to apply based purposes we use cyclic Gibbs sampling, which samples on the input x together with the sampling history. from the conditional distribution of each tag in cyclic The total number of Markov transitions TMpxq made order from left to right. by M characterizes its computational cost on input x. The example in Table1 shows at least two sweeps of How do we choose which transition kernel to apply? cyclic Gibbs sampling are required, because it is hard A natural objective is to maximize the expected log- to know whether “right” is an adjective or a noun un- likelihood under the model ppy | xq of the sampler til the tag for the following word “time” is sampled. output Mpxq: However, the second pass wastes computation by sam- pling other tags that are mostly determined at the first max EqMpy|xqrlog ppy | xqs (1) pass. This inspires the following inference strategy: M ˚ s.t.: TMpxq ď T , pass 1 sample the tags for each word. pass 2 sample the tag for “right”. where qMpy | xq is the probability that M outputs y and T ˚ is the computation budget. Equation (1) says In general, it is desirable to have the inference algo- that we want qM to place as much probability mass rithm itself figure out which locations to sample, and as possible on values of y that have high probability Tianlin Shi, Jacob Steinhardt, Percy Liang

Solving this optimization problem at test time is infeasible, so we will instead optimize M on a training set, and then deploy the resulting M at test time.

3 Reinforcement Learning of Heterogeneous Samplers

We would like to optimize the objective in (1), but searching over all samplers M and evaluating the expectation in (1) are both difficult. We use reinforcement learning (RL) to find an approximate solution.

Reduction to RL. Recall that y[0], y[1], ... is the sequence of outputs generated by our sampler, where y[t+1] ~ A_{j[t]}(· | y[t]). To cast our problem into the RL framework, let the state s_t = (y[0], ..., y[t], j[0], ..., j[t−1]) be the entire history of samples, and let the action a_t = A_{j[t]} be the transition kernel that produces y[t+1] from y[t]. We let the reward be the improvement in log-probability:

  R(s_t, a_t, s_{t+1}) = log p(y[t+1] | x) − log p(y[t] | x).    (2)

Let S be the space of states and A be the space of actions as defined above. The goal of RL is to find a policy π : S → A to maximize the expected cumulative reward E[R_{T*}], where

  R_T = Σ_{t=0}^{T−1} R(s_t, a_t, s_{t+1}).    (3)

Clearly, the total reward R_T is equal to log p(y[T] | x) − log p(y[0] | x); since y[0] is independent of the particular policy, maximizing cumulative reward is equivalent to maximizing the original objective in (1). Although the reward only depends on the last output y[t] of a state s_t, we will allow the policy to depend on the full history encapsulated in s_t.

Learning algorithm. Our algorithm learns an action-value function Q(s_t, A_j) that predicts the value of using A_j in state s_t. Standard reinforcement learning methods, such as Q-learning [Watkins and Dayan, 1992] and SARSA [Rummery and Niranjan, 1994], do not work well for our purpose. The issue is that they attempt to learn a function Q(s_t, A_j) that models the expected cumulative future reward if we take action A_j in state s_t. In our setting this cumulative reward is hard to predict for the following two reasons:

• It is difficult to estimate how far the current sample is from the global optimum. As an approximation, we use Q(s_t, A_j) to predict the cumulative reward over a shorter time horizon H ≪ T*.
• The reward over time H also depends on the context of y_j. By subtracting the contextual part of the reward, we can hope to isolate the contribution of action A_j. Thus, we use Q(s_t, A_j) to model the difference in reward from taking a transition A_j, relative to the baseline of making no transition.

Formally, we learn Q using sample backup [Sutton and Barto, 1998, Chapter 9.5]. We start with some state-action sequence (s_0, a_0, s_1, a_1, ..., s_{T*}) sampled from a fixed exploration policy π'. Then, for each index t (t = 0, ..., T* − 1) in the sequence, we generate a rollout starting from initial state s_1^c = s_{t+1}: for i = 1, 2, ..., H, we generate the action a_i^c = argmax_a Q(s_i^c, a) and the state s_{i+1}^c using a_i^c, and define the utility due to taking a_t:

  U^c = R(s_t, a_t, s_{t+1}) + Σ_{i=1}^{H} R(s_i^c, a_i^c, s_{i+1}^c).    (4)

Next, consider starting from state s_1^b = s_t and not taking a_t, and letting a_i^b = argmax_a Q(s_i^b, a). The resulting states and actions s_i^b, a_i^b define the following utility:

  U^b = Σ_{i=1}^{H} R(s_i^b, a_i^b, s_{i+1}^b).    (5)

To model the Q-function, we use a single-layer neural network with one hidden node [Tesauro, 1995]:

  Q(s, a) = w σ(α · φ(s, a)) + b,    (6)

where σ(·) is the logistic function, φ(s, a) ∈ R^L are meta-features, and α ∈ R^L, w ∈ R, b ∈ R are the meta-parameters; we write θ = (w, b, α).

We update θ with a temporal-difference update [Sutton and Barto, 1998] based on our rollout:

  θ ← θ + η_t ∗ d_t,    (7)

where η_t ∈ R^{L+2} are step sizes, "∗" is element-wise multiplication, and the vector d_t is

  d_t = (U^c − U^b − Q(s_t, a_t)) ∇_θ Q(s_t, a_t).    (8)

For choosing η_t, we use AdaGrad [Duchi et al., 2010]:

  η_t = η / (δ + sqrt(Σ_{i=0}^{t} d_i ∗ d_i)),    (9)

where η is the meta learning rate and δ is a smoothing parameter (we use η = 1 and δ = 10^{-4} in experiments).
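As a minimal sketch of equations (6)-(9), and assuming the meta-feature vector φ and the rollout utilities U^c and U^b have already been computed, the Q-function and its AdaGrad-scaled temporal-difference update could look as follows; the class and variable names are ours, not from the paper's code.

```python
import numpy as np

class QFunction:
    """Q(s, a) = w * sigmoid(alpha . phi(s, a)) + b, with theta = (w, b, alpha)."""

    def __init__(self, L, eta=1.0, delta=1e-4):
        self.w, self.b = 0.0, 0.0
        self.alpha = np.zeros(L)
        self.eta, self.delta = eta, delta
        self.grad_sq = np.zeros(L + 2)      # AdaGrad accumulator for (w, b, alpha)

    def _sigma(self, phi):
        return 1.0 / (1.0 + np.exp(-np.dot(self.alpha, phi)))

    def value(self, phi):
        return self.w * self._sigma(phi) + self.b

    def grad(self, phi):
        # Gradient of Q with respect to (w, b, alpha).
        s = self._sigma(phi)
        return np.concatenate(([s, 1.0], self.w * s * (1.0 - s) * phi))

    def td_update(self, phi, U_c, U_b):
        """One temporal-difference step: d_t = (U^c - U^b - Q) * grad Q,
        scaled element-wise by AdaGrad step sizes (equations 7-9)."""
        d = (U_c - U_b - self.value(phi)) * self.grad(phi)
        self.grad_sq += d * d
        step = self.eta / (self.delta + np.sqrt(self.grad_sq))
        theta = np.concatenate(([self.w, self.b], self.alpha)) + step * d
        self.w, self.b, self.alpha = theta[0], theta[1], theta[2:]

# Example update with a 5-dimensional meta-feature vector.
q = QFunction(L=5)
q.td_update(phi=np.ones(5), U_c=0.3, U_b=0.1)
print(q.value(np.ones(5)))
```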

We use the cyclic Gibbs sampler as the fixed exploration policy π' to generate the initial states for the rollouts. The entire training procedure is shown in Algorithm 2.

Algorithm 2 Learning a heterogeneous sampler.
1: Input: dataset X, transition kernels A_j, number of epochs E, cyclic Gibbs policy π', time horizon H.
2: Initialize y ~ P_0(y). Set s_0 = (y[0]).
3: for epoch = 1, ..., E do
4:   for t = 0, ..., T* − 1 do
5:     Get action a_t = π'(s_t).
6:     Sample s_{t+1} from a_t.
7:     Extract φ(s_t, a_t).
8:     Estimate the gradient using (8).
9:     Update the meta-parameters θ via (7).
10:  end for
11: end for

Test time. At test time, we apply Algorithm 1. In particular, we maintain a priority queue over locations j, where the priorities are the Q-values.

• During the select step, compute argmax_{A_j} Q(s_t, A_j) for state s_t, popping off the maximum element from the priority queue.
• After the sample step, re-compute the meta-features of the actions that depend on j[t] and update the corresponding Q-values in the priority queue.

In terms of computational complexity, select takes O(log m), and re-scoring takes O(L n_j log m), where m is the number of variables and n_j is the number of neighbors of variable y_j. The complexity of meta-feature computation will be discussed in Section 4.
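The select/sample loop above might be organized as in the sketch below. The paper maintains a priority queue keyed by Q-values; here we use Python's heapq with lazy invalidation of stale entries, which is one simple way to obtain the same O(log m) behavior. The callbacks q_value and sample stand in for the meta-model and the transition kernels and are assumptions of this sketch, not part of the paper's implementation.

```python
import heapq

def run_test_time(variables, neighbors, q_value, sample, budget):
    """Test-time loop: a max-priority queue over locations j keyed by Q-values.
    q_value(j) scores kernel A_j from its current meta-features and sample(j)
    applies A_j; stale heap entries are skipped lazily instead of being re-keyed."""
    version = {j: 0 for j in variables}            # bumped whenever j's score changes
    heap = [(-q_value(j), j, 0) for j in variables]
    heapq.heapify(heap)

    for _ in range(budget):
        # select: pop entries until we find one whose priority is still current.
        while True:
            neg_q, j, ver = heapq.heappop(heap)
            if ver == version[j]:
                break
        sample(j)                                   # apply transition kernel A_j
        # re-score j and its neighbors, whose meta-features may have changed.
        for k in [j] + list(neighbors[j]):
            version[k] += 1
            heapq.heappush(heap, (-q_value(k), k, version[k]))
```

Lazy invalidation keeps each heap operation at O(log m) while avoiding an indexed heap; an in-place decrease-key structure, as implied by the paper, has the same asymptotic cost.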
4 Meta-Features

The effectiveness of our framework hinges on the existence of good meta-features for the Q-function in (6). Our goal is to develop general meta-features that exhibit strong predictive power across many datasets. In this section, we offer a few guiding principles, which culminate in a set of five meta-feature templates.

Principles. One necessary criterion is that computing the meta-features should be computationally cheap relative to the rest of the inference algorithm. Without loss of generality, we assume each output variable y_j takes one of K values and has n_j neighbors. Then such meta-features should take, for example, O(K) or O(n_j) to compute, and they should not be computed more often than the variables are sampled.

In order to satisfy this criterion, an idea that we have found consistently useful is stale values. For example, suppose we would like to use the entropy of p(y_j | y_{-j}) as a meta-feature. Computing it would be just as expensive as sampling from A_j. Instead, we keep track of a stale version of the conditional entropy. Every time we sample from A_j, we compute the entropy of y_j according to the sampling distribution and store it in memory. The stale conditional entropy of y_j is then defined as the current entropy in memory. As the stale conditional entropy is a meta-feature, we leave it up to the learning algorithm to determine how much to trust it.

Reasoning about staleness is valuable in another way as well: we want to know how different the Markov blanket of y_j is relative to the last time it was sampled. If it is very different, this tells us two things: first, any stored quantities such as entropy that we have computed are probably out of date. More importantly, the conditional distribution p(y_j | y_{-j}) is probably very different than when we last sampled from A_j, so it is probably a good idea to sample from A_j again.

In addition to staleness, another important notion is discord. If at least one neighbor of y_j is inconsistent with the current value of y_j, then it is probably worth sampling from A_j.

Templates. Based on the ideas of staleness and discord, we introduce the following meta-feature templates:

vary: the number of variables in the Markov blanket of y_j that have changed since the last time we sampled from A_j. The more neighbors of y_j that have changed, the more stale y_j could be. This meta-feature is computed as follows: upon sampling y_j, we set vary(y_j) = 0, and increment vary of its neighbors by 1. So the computational complexity of vary is only O(n_j).

nb: discord with neighbors. We define meta-features nb-y-y' (y, y' ∈ {1, ..., K}), which are binary indicators that track the occurrences of each of the possible value pairs (y, y') for y_j and one of its neighbors. The intuition is that certain pairs of values are very unlikely to occur as part of a legitimate output, and thus represent a discordance that requires taking more samples; the nb meta-features allow us to learn these pairs from data. Although the total number of nb meta-features is K^2, the computational complexity is still O(n_j) due to sparsity.

cond-ent: the conditional entropy of y_j at the last time it was sampled. Variables with high conditional entropy have a high degree of uncertainty and can benefit from being sampled further. However, as noted earlier, keeping track of an up-to-date conditional entropy would be too computationally expensive. We therefore use a stale version based on the last time that the variable y_j in question was sampled by A_j. The computational overhead of cond-ent is O(K), since we can just compute the entropy based on p(y_j | y_{-j}), which already needs to be computed in order to sample from A_j.

unigram-ent: the entropy under a unigram model. Sometimes, we can fit a simpler unigram model to the training dataset, where all the variables y_j are independent (given the input). In such cases, the unigram entropy of p(y_j) can be used as an over-estimate of the degree of ambiguity for a given variable [Bishop, 2006]. If the unigram entropy is very low, the variable is probably not worth sampling. To capture this, we use the indicator function for the unigram entropy being below some threshold (10^{-4} in our experiments).

sp: the number of times y_j has been sampled thus far. A potential problem with all of the above meta-features is that they might overly explore possibilities for the same variable. So we need some way to capture the fact that, at some point, sampling the same variable further is unlikely to lead to improvements. We do this by keeping track of the number of times a variable has already been sampled: once a variable has been sampled too many times, it is unlikely that sampling it further will be fruitful.
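A possible way to organize the bookkeeping behind these templates is sketched below (unigram-ent is omitted since it is a fixed, precomputed indicator). The class and method names are ours, introduced for illustration; the point is only that each update touches O(n_j) counters when A_j fires.

```python
import math
from collections import defaultdict

class MetaFeatures:
    """Sketch of per-variable bookkeeping for the vary, sp, cond-ent, and nb templates.
    Updates run in O(n_j) per sample of y_j; the stored entropy is the *stale*
    conditional entropy, refreshed only when y_j itself is resampled."""

    def __init__(self, num_vars, neighbors):
        self.neighbors = neighbors                     # j -> list of neighbor indices
        self.vary = [0] * num_vars                     # neighbors changed since last sample
        self.sp = [0] * num_vars                       # times sampled so far
        # Never-sampled variables are treated as maximally uncertain (a convention
        # of this sketch, not specified in the paper).
        self.cond_ent = [float("inf")] * num_vars

    def on_sample(self, j, cond_probs):
        """Call right after kernel A_j resamples y_j with distribution cond_probs."""
        self.sp[j] += 1
        self.vary[j] = 0
        self.cond_ent[j] = -sum(p * math.log(p) for p in cond_probs if p > 0)
        for k in self.neighbors[j]:
            self.vary[k] += 1

    def features(self, j, y):
        """Meta-feature dictionary for candidate action A_j (nb indicators kept sparse)."""
        nb = defaultdict(float)
        for k in self.neighbors[j]:
            nb[(y[j], y[k])] = 1.0                     # discord indicator nb-y-y'
        return {"vary": self.vary[j], "sp": self.sp[j],
                "cond-ent": self.cond_ent[j], "nb": dict(nb)}
```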
Figure 1: Visualization of computational resource allocation on a test example from NER-f4: (a) the cyclic Gibbs sampler, (b) HeteroSampler. Each row is a snapshot of the sample after k = 1, 2, 3 sweeps; a darker color means that more cumulative samples have been taken at that token. For HeteroSampler, one sweep corresponds to making m transitions, where m is the number of variables. HeteroSampler has learned to sample harder parts of the instance, such as ambiguous tokens.

5 Experiments

In this section, we provide an empirical evaluation of our method, which we call HeteroSampler.

5.1 Datasets

We evaluate our method on five tasks: part-of-speech tagging (POS), named-entity recognition (NER), handwriting recognition, color inpainting, and scene decomposition.

The general setup is as follows. First, we use RL to learn the parameters of the meta-model on the training dataset. We then run Algorithm 1 on the test set. Unless otherwise stated, in all experiments we use E = 3 training epochs, step size η = 1, and time horizon H = 1. In addition to using cyclic Gibbs to generate a base policy for training, we also compare to cyclic Gibbs at test time.

We evaluated on the following datasets:

POS/NER Tagging. The POS tagging dataset comes from the standard Wall Street Journal (WSJ) section of the Penn Treebank, and the NER tagging dataset is taken from the 2003 CoNLL Shared Task. We trained a CRF model with features between each token and its corresponding tag (i.e., features [g(x_i), y_i] with feature extractor g(·)), and higher-order features between/among tags (i.e., [y_i, y_{i+1}, ..., y_{i+q}]) [Liang et al., 2008]. We refer to q as the factor size, and call the corresponding tasks POS-fq and NER-fq. The feature extraction functions g(·) include prefixes and suffixes of the word (up to length 4), the lowercased word, the word signature (e.g., McDiarmid maps to AA and AaAaaaaaa, banana maps to a and aaaaaa), and an indicator of whether the word is capitalized. The CRF is trained using AdaGrad [Duchi et al., 2010] with 5 passes over the training set, using Gibbs sampling for inference with 8 total sweeps over each instance and 5 sweeps as burn-in. Performance is evaluated via tag accuracy for POS and F1 score for NER.

Handwriting Recognition. We use the handwriting recognition dataset from Weiss and Taskar [2010]. The data were originally collected by Kassel [1995], and contain 6877 handwritten words from 150 subjects, with 55 distinct words. Each instance of the dataset is a word, which is a sequence of characters. Associated with each output character is a corresponding 16 × 8 binary optical image. The dataset is split into a training set of 6251 words and a test set of 626 words. Similar to POS/NER tagging, the baseline algorithm is a CRF. We have a feature for each (pixel value, location, character) triple, as well as higher-order n-gram potentials between consecutive characters.

The training scheme of the full model is similar to POS/NER, except that we use 16 Gibbs sampling sweeps in total with 5 burn-in sweeps. The results are evaluated via character-wise accuracy.

Figure 2: Accuracy, F1 score, or log-probability vs. average number of transitions across tasks: (a) NER (factor size 2), (b) NER (factor size 4), (c) POS (factor size 4), (d) OCR (factor size 4), (e) Color Inpainting, (f) Scene Decomposition, comparing HeteroSampler with cyclic Gibbs. HeteroSampler converges much faster than cyclic Gibbs sampling. On the color inpainting task, dynamically choosing the order of sampling also allows the algorithm to find a better local optimum.

Color Inpainting. The three-class color inpainting task is borrowed from Chambolle et al. [2012]. The input is a corrupted color image in a circular domain, and the target image is an equipartition of the circle using three colors. We use a pre-trained model from the OpenGM benchmark [Kappes et al., 2013]. The baseline is Gibbs sampling with 100 sweeps over the instance. There are two instances in this dataset; we use one to train the HeteroSampler and the other to test it. Performance is evaluated based on the log-probability of the output.

Scene Decomposition. The scene decomposition dataset is obtained from the source in Gould et al. [2009]. The goal is to segment a natural image into eight semantic categories, such as grass and sky. We use the subset of 715 images included in the OpenGM toolkit [Kappes et al., 2013], for which an existing model is publicly available. The graphical model is a superpixel factor graph, and each superpixel has 773 feature dimensions. Among the 715 images, we randomly pick 358 instances for training, and the rest are used for testing. The baseline Gibbs sampling uses 16 sweeps over each image, and we train the HeteroSampler in the same way as for Color Inpainting. Performance is again evaluated via log-probability.

5.2 Visualization of Resource Allocation

Figure 1 visualizes the allocation of sampling operations on an NER instance. While cyclic Gibbs sampling uniformly distributes its computational resources, HeteroSampler is able to focus on harder parts of the task, such as names.

5.3 Performance under Different Budgets

To measure performance under different budgets, we gradually increase the total number of transitions at test time for the HeteroSampler and the overall number of sweeps for the Gibbs baseline. The number of transitions used for training is held fixed.

Figure 2 plots the performance versus the average number of transitions per instance. As we see, given the same budget, HeteroSampler achieves equal or better performance for all tasks. In addition, HeteroSampler reaches the ceiling accuracy 2 to 5 times faster regardless of the problem domain. For the Color Inpainting problem, which is the most challenging of the tasks, HeteroSampler also achieves better end performance when it converges. This is due to the fact that, by optimizing the order of sampling, the meta-algorithm is able to find a better local optimum.

Next, we justify measuring computational cost in terms of the number of samples. To do this, we measured the overhead of computing the policy relative to the cost of sampling. As long as this overhead is low, computing samples is the bottleneck in the base algorithm, and so the number of samples is a reasonable measure of computational cost. Figure 3 shows how many wall-clock seconds were spent computing the HeteroSampler policy for each dataset. For most tasks, sampling involves computing all of the features of the full model, while computing the policy uses only a few meta-features and therefore has negligible cost. The exception is color inpainting, where the full model has only a few features.

Figure 3: Wall-clock time vs. average number of transitions across datasets: (a) NER (factor size 2), (b) NER (factor size 4), (c) POS (factor size 4), (d) OCR (factor size 4), (e) Color Inpainting, (f) Scene Decomposition. This measures the overhead of policy evaluation: one line shows the time spent without policy evaluation, while the "Policy" line shows the time spent on policy evaluation, and "Overall" the total.

input:      KANSAS  CITY   AT  OAKLAND
immediate:  B-LOC   I-LOC  O   B-LOC
cumulative: B-ORG   I-ORG  O   B-LOC

Table 2: Cumulative rewards are often helpful when there is high correlation between variables. In this example, "KANSAS CITY" is initially labeled as a location. Two coordinated actions are needed to change it to an organization and improve the log-likelihood. A meta-model trained with immediate rewards would not recognize the value of sampling "KANSAS" alone.

5.4 Cumulative Rewards

We would like to verify two facts: first, that training with cumulative rewards is useful relative to just using immediate rewards; second, that our meta-features can predict the cumulative rewards. First, Figure 4(a) shows accuracy vs. number of transitions with H = 0 (immediate rewards) and H = 1 (cumulative rewards) on NER-f4. As we can see, with cumulative rewards, HeteroSampler performs better. Table 2 provides an intuitive explanation.

To see which meta-features are contributing to the prediction of cumulative rewards, we intentionally add an oracle meta-feature, which is the immediate reward of sampling. Figures 4(b) and (c) visualize the weights of some meta-features for H = 0 and H = 1, respectively. As expected, when learned with immediate rewards, all weight concentrates on the oracle meta-feature. The learned meta-model does not encourage exploration and therefore may omit positions that have cumulative reward. When trained with cumulative rewards, the meta-model also distributes some weight to cond-ent, which leads to better exploration and better performance.

Figure 4: Effect of training with cumulative rewards. (a) Convergence curve of the inference algorithm trained with one-step look-ahead (H = 1) vs. immediate rewards (H = 0). (b) Weights of the meta-features when trained with H = 0 plus the oracle meta-feature. (c) Weights of the meta-features when trained with H = 1 plus the oracle meta-feature.

5.5 Meta-feature Ablation Analysis

To evaluate the contribution of individual meta-features and to understand their role in predicting reward, we performed a meta-feature ablation analysis. With one meta-feature removed at a time, we run the meta-algorithm and produce a convergence curve, shown in Figure 5 for NER-f2 and POS-f4. All meta-features play an important role. sp is the most important; without it, we would repeatedly sample variables with high uncertainty.

6 Related Work and Discussion

At a high level, our approach is about fine-tuning the inference algorithm of a pre-trained model to make more effective use of computational resources at test time. Specifically, we use reinforcement learning to train a sampler that operates on variables heterogeneously based on the promise of likelihood improvements. We demonstrated substantial speed improvements on several structured prediction tasks.

The idea of treating structured prediction as a sequential decision-making problem has been explored by SEARN [Daume et al., 2009] and DAGGER [Ross et al., 2011a]. Both train a multiclass classifier to build up a structured output incrementally in a fixed order. Similar ideas have been applied in dependency parsing [Goldberg and Nivre, 2013].

Figure 5: Meta-feature ablation study: convergence curves of HeteroSampler with one meta-feature removed at a time on (a) NER-f2 and (b) POS-f4 ("all" denotes the entire meta-feature set, and "\" denotes excluding a meta-feature). We see that each meta-feature matters for at least some of the tasks, and sp is the most important.

To obtain speedups, it is beneficial to learn the order in which the structured output is constructed; this flexibility is the cornerstone of our work. For example, Goldberg and Elhadad [2010] proposed an approach that learns to construct a dependency tree by adding "easy" arcs first. More generally, Jiang et al. [2012] maintain a priority queue over partially constructed hypotheses for constituency parsing and learn to choose which one to process first. While the aforementioned work builds up outputs incrementally, our heterogeneous sampler makes modifications to full outputs, which can be more flexible.

Other work also operates in the space of full outputs. For example, Doppa et al. [2014b,a] perform several steps of local search around a baseline prediction. Zhang et al. [2014] performed greedy hill-climbing from multiple random starting points for dependency parsing. Ross et al. [2011b] used DAGGER to learn message-passing inference algorithms. However, unlike our method, none of these papers deal with the issue of determining which locations are useful to operate on without explicitly evaluating the model score for each candidate modification. We use lightweight meta-features for this purpose.

Some methods use a fixed strategy to prioritize inference in a fixed model. For example, residual belief propagation [Elidan et al., 2006] selects the message between two variables that has changed the most from the previous iteration. In cases where we are interested in a particular query variable, Chechetka and Guestrin [2010] prioritize messages based on importance to the query. Wick and McCallum [2011] implement the same intuition in the context of MCMC.

SampleRank [Wick et al., 2011] also performs learning in the context of sampling, but is complementary to our work: SampleRank fixes a sampling strategy and trains the underlying model, whereas we fix the underlying model and train the sampling strategy using domain-general meta-features. Wick et al. [2009] tackle the local optima problem in structured prediction by using RL to train policies that can select fruitful downward jumps, which is the same issue that our use of cumulative rewards attempts to address.

More generally, the goal of speeding up inference at test time is quite established by now. Viola and Jones [2001] used a sequence of models from simple to complex for face detection, at each successive stage pruning out unlikely locations in the image. Weiss and Taskar [2010] trained a sequence of Markov models of increasing order, each successive stage pruning out unlikely local configurations.

As feature extraction is often the performance bottleneck, it is a promising place to look for speed improvements. Weiss and Taskar [2013] used RL to train policies that adaptively determine the value of information of each feature at test time. For dependency parsing, He et al. [2013] consider a sequence of increasingly complex features, and use DAGGER to learn which arcs to commit to before adding more features.

Our work is superficially related to the work on adaptive MCMC [Andrieu and Thoms, 2008], but the goals are quite different. Adaptive MCMC samplers attempt to preserve the stationary distribution, while our approach seeks to directly maximize log-likelihood within a fixed number of time steps.

As a final remark, in this paper we only focused on the issue of "where to sample", but the general framework, which merely learns which transition kernels to apply, could also be applied to determine "how to sample" by supplying a richer family of transition kernels, for example ones based on models with different feature sets or blocked samplers. This opens up a vast set of possibilities for finer-grained adaptivity.
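As a purely illustrative sketch of this last point (not something implemented in the paper), a blocked Gibbs kernel fits the same interface as the single-variable kernels A_j, so the learned policy could in principle choose among kernels of both kinds. The joint_conditional callback below is a hypothetical stand-in for whatever block-conditional sampler a given model provides.

```python
import numpy as np

rng = np.random.default_rng(0)

def blocked_gibbs_kernel(y, block, joint_conditional):
    """A richer transition kernel: resample a whole block of variables jointly.
    joint_conditional(y, block) must return (configs, probs) for the block given
    the rest of y; it is an assumed helper, not part of the paper's framework."""
    configs, probs = joint_conditional(y, block)
    y_new = y.copy()
    choice = configs[rng.choice(len(configs), p=probs)]
    for j, value in zip(block, choice):
        y_new[j] = value
    return y_new

# Toy usage: a uniform conditional over a block of two binary variables.
uniform = lambda y, block: ([(a, b) for a in (0, 1) for b in (0, 1)], np.full(4, 0.25))
print(blocked_gibbs_kernel(np.zeros(5, dtype=int), [1, 2], uniform))
```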
Acknowledgements. This work was supported by the Fannie & John Hertz Foundation for the second author and the Microsoft Research Faculty Fellowship for the third author.

References

C. Andrieu and J. Thoms. A tutorial on adaptive MCMC. Statistics and Computing, 18(4):343–373, 2008.

C. M. Bishop. Pattern Recognition and Machine Learning. Springer New York, 2006.

S. Brooks, A. Gelman, G. Jones, and X. Meng. Handbook of Markov Chain Monte Carlo. CRC Press, 2011.

A. Chambolle, D. Cremers, and T. Pock. A convex approach to minimal partitions. SIAM Journal on Imaging Sciences, 5(4):1113–1158, 2012.

A. Chechetka and C. Guestrin. Focused belief propagation for query-specific inference. In Artificial Intelligence and Statistics (AISTATS), 2010.

H. Daume, J. Langford, and D. Marcu. Search-based structured prediction. Machine Learning, 75:297–325, 2009.

J. R. Doppa, A. Fern, and P. Tadepalli. HC-Search: A learning framework for search-based structured prediction. Journal of Artificial Intelligence Research, 50:403–439, 2014a.

J. R. Doppa, A. Fern, and P. Tadepalli. Structured prediction via output space search. Journal of Machine Learning Research, 15:1317–1350, 2014b.

J. Duchi, E. Hazan, and Y. Singer. Adaptive subgradient methods for online learning and stochastic optimization. In Conference on Learning Theory (COLT), 2010.

G. Elidan, I. McGraw, and D. Koller. Residual belief propagation: Informed scheduling for asynchronous message passing. In Uncertainty in Artificial Intelligence (UAI), 2006.

Y. Goldberg and M. Elhadad. An efficient algorithm for easy-first non-directional dependency parsing. In Association for Computational Linguistics (ACL), pages 742–750, 2010.

Y. Goldberg and J. Nivre. Training deterministic parsers with non-deterministic oracles. Transactions of the Association for Computational Linguistics (TACL), 1, 2013.

S. Gould, R. Fulton, and D. Koller. Decomposing a scene into geometric and semantically consistent regions. In International Conference on Computer Vision (ICCV), 2009.

H. He, H. Daume, and J. Eisner. Dynamic feature selection for dependency parsing. In Empirical Methods in Natural Language Processing (EMNLP), 2013.

J. Jiang, A. Teichert, J. Eisner, and H. Daume. Learned prioritization for trading off accuracy and speed. In Advances in Neural Information Processing Systems (NIPS), 2012.

J. H. Kappes, B. Andres, F. A. Hamprecht, C. Schnorr, S. Nowozin, D. Batra, S. Kim, B. X. Kausler, J. Lellmann, N. Komodakis, and C. Rother. A comparative study of modern inference techniques for discrete energy minimization problems. In Computer Vision and Pattern Recognition (CVPR), 2013.

R. H. Kassel. A comparison of approaches to on-line handwritten character recognition. PhD thesis, Massachusetts Institute of Technology, 1995.

D. Koller, N. Friedman, L. Getoor, and B. Taskar. Graphical models in a nutshell. Statistical Relational Learning, page 13, 2007.

P. Liang, H. Daume, and D. Klein. Structure compilation: Trading structure for features. In International Conference on Machine Learning (ICML), 2008.

A. McCallum and W. Li. Early results for named entity recognition with conditional random fields, feature induction and web-enhanced lexicons. In Proceedings of the Seventh Conference on Natural Language Learning at HLT-NAACL 2003, pages 188–191. Association for Computational Linguistics, 2003.

S. Ross, G. Gordon, and A. Bagnell. A reduction of imitation learning and structured prediction to no-regret online learning. In Artificial Intelligence and Statistics (AISTATS), 2011a.

S. Ross, D. Munoz, M. Hebert, and J. A. Bagnell. Learning message-passing inference machines for structured prediction. In Computer Vision and Pattern Recognition (CVPR), pages 2737–2744, 2011b.

G. A. Rummery and M. Niranjan. Online Q-learning using connectionist systems. University of Cambridge, Department of Engineering, 1994.

R. S. Sutton and A. G. Barto. Introduction to Reinforcement Learning. MIT Press, 1998.

G. Tesauro. Temporal difference learning and TD-Gammon. Communications of the ACM, 38(3):58–68, 1995.

E. Veach. Robust Monte Carlo methods for light transport simulation. PhD thesis, Stanford University, 1997.

P. Viola and M. Jones. Rapid object detection using a boosted cascade of simple features. In Computer Vision and Pattern Recognition (CVPR), 2001.

C. Watkins and P. Dayan. Q-learning. Machine Learning, 8(3-4):279–292, 1992.

D. Weiss and B. Taskar. Structured prediction cascades. In Artificial Intelligence and Statistics (AISTATS), 2010.

D. Weiss and B. Taskar. Learning adaptive value of information for structured prediction. In Advances in Neural Information Processing Systems (NIPS), 2013.

M. Wick, K. Rohanimanesh, S. Singh, and A. McCallum. Training factor graphs with reinforcement learning for efficient MAP inference. In Advances in Neural Information Processing Systems (NIPS), 2009.

M. Wick, K. Rohanimanesh, and K. Bellare. SampleRank: Training factor graphs with atomic gradients. In International Conference on Machine Learning (ICML), 2011.

M. L. Wick and A. McCallum. Query-aware MCMC. In Advances in Neural Information Processing Systems (NIPS), pages 2564–2572, 2011.

Y. Zhang, T. Lei, R. Barzilay, and T. Jaakkola. Greed is good if randomized: New inference for dependency parsing. In Empirical Methods in Natural Language Processing (EMNLP), 2014.