Generative Interventions for Causal Learning
Chengzhi Mao1, Augustine Cha1*, Amogh Gupta1*, Hao Wang2, Junfeng Yang1, Carl Vondrick1
1Columbia University, 2Rutgers University
{mcz, junfeng, vondrick}@cs.columbia.edu, {ac4612, ag4202}@columbia.edu, [email protected]
*Equal Contribution. Order by coin flip.

Abstract

We introduce a framework for learning robust visual representations that generalize to new viewpoints, backgrounds, and scene contexts. Discriminative models often learn naturally occurring spurious correlations, which cause them to fail on images outside of the training distribution. In this paper, we show that we can steer generative models to manufacture interventions on features caused by confounding factors. Experiments, visualizations, and theoretical results show this method learns robust representations more consistent with the underlying causal relationships. Our approach improves performance on multiple datasets demanding out-of-distribution generalization, and we demonstrate state-of-the-art performance generalizing from ImageNet to the ObjectNet dataset.

Figure 1. Top predictions ("Ladle", "Television", "Shovel") from a state-of-the-art ImageNet classifier [21]. The model uses spurious correlations (scene contexts, viewpoints, and backgrounds), leading to incorrect predictions; the correct categories are clearly a broom, a tray, and a shoe. In this paper, we introduce a method to learn causal visual features that improve the robustness of visual recognition models. The predictions of our model are in Figure 7.

1. Introduction

Visual recognition today is governed by empirical risk minimization (ERM), which bounds the generalization error when the training and testing distributions match [47]. When training sets cover all factors of variation, such as background context or camera viewpoints, discriminative models learn invariances and predict object category labels with the right cause [33]. However, the visual world is vast and naturally open. Collecting a representative, balanced dataset is difficult and, in some cases, impossible because the world can unpredictably change after learning.
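Concretely, ERM selects the classifier that minimizes the average loss over the training samples. The following standard formulation (the notation is ours and is not taken from this paper) makes the distributional assumption explicit:

\hat{f} = \arg\min_{f \in \mathcal{F}} \; \frac{1}{n} \sum_{i=1}^{n} \ell\big(f(x_i), y_i\big),

where the empirical average approximates the population risk \mathbb{E}_{(x,y) \sim P}\,[\ell(f(x), y)] only when the training pairs (x_i, y_i) and the test images are drawn from the same distribution P. The failures discussed below arise exactly when this assumption is violated.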
Directly optimizing the empirical risk is prone to learning unstable spurious correlations that do not respect the underlying causal structure [11, 8, 24, 44, 4, 35]. Figure 1 illustrates the issue succinctly. In natural images, the object of interest and the scene context have confounding factors, creating spurious correlations. For example, a ladle (the object of interest) often has a hand holding it (the scene context), but there is no causal relation between them. Several studies have exposed this challenge by demonstrating substantial performance degradation when the confounding bias no longer holds at testing time [41, 19]. For example, the ObjectNet [6] dataset removes several common spurious correlations from the test set, causing the performance of state-of-the-art models to deteriorate by 40% compared to the ImageNet validation set.

A promising direction for fortifying visual recognition is to learn causal representations (see [43] for an excellent overview). If representations are able to identify the causal mechanism between the image features and the category labels, then robust generalization is possible. While the traditional approach to establishing causality is through randomized controlled trials or interventions, natural images are passively collected, preventing the use of such procedures.

This paper introduces a framework for learning causal visual representations with natural images. Our approach is based on the observation that generative models quantify nuisance variables [23, 26], such as viewpoint or background. We present a causal graph that models both robust features and spurious features during image recognition. Crucially, our formulation shows how to learn causal features by steering generative models to perform interventions on realistic images, simulating manipulations to the camera and scene that remove spurious correlations. As our approach is model-agnostic, we are able to learn robust representations for any state-of-the-art computer vision model.

Our empirical and theoretical results show that our approach learns representations that respect causal structures. While simply sampling from generative models will replicate the same training set biases, steering the generative models allows us to reduce the bias, which we show is critical for performance. On the ImageNet-C [22] benchmark, we surpass established methods by up to 12%, which shows that our method helps discriminate based on the causal features. Our approach also demonstrates state-of-the-art performance on the new ObjectNet dataset [6]. We obtain 39.3% top-1 accuracy with ResNet152 [21], a gain of over 9% over the published ObjectNet benchmark [6], while increasing accuracy on ImageNet and ImageNet-V2 [41]. We will release code, data, and models.

Figure 2. Generative adversarial networks are steerable [23, 26], allowing us to manipulate images and construct interventions on nuisances such as rotation (principal component 5), scale (principal component 10), and background (principal component 12). The transformations transfer across categories. Each column in the figure presents images with one consistent intervention direction.

2. Related Work

Data augmentation: Data augmentation often helps learn robust image classifiers. Most existing data augmentations use lower-level transformations [29, 46], such as rotation, contrast, brightness, and shear. Auto-data augmentation [13, 52] uses reinforcement learning to optimize the combination of those lower-level transformations. Other work, such as cutout [15] and mixup [51], develops new augmentation strategies for improved generalization. Style transfer has also been explored to augment the training data [34, 54, 19]; however, those transformations are limited to texture and color changes. Adversarial training, where images are augmented by adding adversarial noise, can also train robust models [50]. However, both adversarial training [50] and auto augmentation [13, 52] introduce orders of magnitude more computational overhead. In addition, none of the above methods can perform high-level transformations such as changing the viewpoint or background [6], while our generative interventions can. Our method fundamentally differs from prior data augmentation methods because it learns a robust model by estimating causal effects via generative interventions. Our method not only eliminates spurious correlations more effectively than data augmentation, but also theoretically produces a tighter causal effect bound.

Causal Models: Causal image classifiers generalize well despite environmental changes because they are invariant to the nuisances caused by the confounding factors [8]. A large body of work studies how to acquire causal effects from a combination of association levels and intervention levels [31, 11, 32]. Ideally, we can learn an invariant representation across different environments and nuisances [8, 35] while maintaining the causal information [5]. While structural risk minimization, such as regularization [27], can also promote a model's causality, this paper focuses on training models under ERM [47].

Generative Models: Our work leverages recent advances in deep generative models [20, 28, 25, 10, 40]. Deep generative models capture the joint distribution of the data, which complements discriminative models [25, 37, 38]. Prior work has explored adding data sampled from a deep generator to the original training data to improve classification accuracy on ImageNet [39]; we denote this as GAN Augmentation in this paper. Other works improved classification accuracy under imbalanced and insufficient data by oversampling through a deep generator [16, 17, 49, 7]. However, when sampling without intervention, the augmented data still follows the same training joint distribution, and the unobserved confounding bias continues to contaminate the generated data. Thus, the resulting models still fail to generalize once the spurious correlations change. Ideally, we want to generate data independent of the spurious correlations while holding the object's causal features fixed.
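As a rough illustration of this idea, the sketch below is ours and is only a minimal example: it assumes a pretrained class-conditional generator G(z, y), a tensor of unit-norm latent steering directions (for example, the principal components visualized in Figure 2), and PyTorch; the helper name generate_intervened_batch is hypothetical and does not come from the paper's released code.

import torch

def generate_intervened_batch(G, class_ids, directions, latent_dim=128, strength=3.0):
    # class_ids: (B,) LongTensor of labels; directions: (K, latent_dim) unit-norm tensor.
    batch = class_ids.shape[0]
    z = torch.randn(batch, latent_dim)                  # latent code per sample
    idx = torch.randint(len(directions), (batch,))      # pick one nuisance direction per sample
    d = directions[idx]                                 # (B, latent_dim)
    alpha = (torch.rand(batch, 1) * 2 - 1) * strength   # random magnitude in [-strength, strength]
    z_int = z + alpha * d                               # randomize viewpoint, scale, background, ...
    with torch.no_grad():
        images = G(z_int, class_ids)                    # the class label (the causal factor) is unchanged
    return images, class_ids

One plausible way to use such batches is to mix them with the original training images and minimize the usual classification loss, so the model repeatedly sees each category under randomized nuisances rather than under the biased nuisances of the training set.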
Recent works analyzing deep generative models show that different variations, such as viewpoint and background, are automatically learned [26, 23]. We leverage deep generative models for constructing interventions in realistic visual data. Our work randomizes a large class of steerable variations, which further shifts the observed data distribution toward independence from the confounding bias. Our approach manipulates high-level transformations that are orthogonal to traditional data augmentation strategies [12, 46, 29], and we obtain additional performance gains by combining them.

Domain Adaptation: Our goal is to train robust models that generalize to unforeseen data. Accessing the test data distribution, even unlabeled, could lead to overfitting [...]

Figure 3. Do unwanted correlations exist between the nuisance factors (e.g., backgrounds, viewpoint) and labels on ImageNet? We measure correlation [...] (Plot axes: relative correlation vs. number of categories, 10 to 1000; legend: chance (ideal), ImageNet original, Mixup, Stylized ImageNet, AutoAug, GAN Augmentation, GenInt (ours).)

Figure 4. Causal graph for image classification. Gray variables are observed. (a) F is the variable that generates the object features. The unobserved confounder C causes both the background features [...] (Panels (a)-(c) show graphs over the variables C, F, X, Y, and Z, including Z_i and Z_U.)
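As background for reading causal graphs like those in Figure 4, the standard backdoor-adjustment identity from causal inference (a textbook result, not a derivation specific to this paper) shows why intervening on a variable, rather than merely conditioning on it, removes the influence of a confounder C that affects both X and Y:

P(y \mid \mathrm{do}(x)) = \sum_{c} P(y \mid x, c) \, P(c),

whereas ordinary conditioning gives P(y \mid x) = \sum_{c} P(y \mid x, c) \, P(c \mid x), which retains the confounded term P(c \mid x). Randomizing nuisance factors with a steered generator is, roughly, an attempt to approximate the interventional quantity on the left using only observational image data.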