CF-GNNExplainer: Counterfactual Explanations for Graph Neural Networks

Ana Lucic, University of Amsterdam, Amsterdam, Netherlands
Maartje ter Hoeve, University of Amsterdam, Amsterdam, Netherlands
Gabriele Tolomei, Sapienza University of Rome, Rome, Italy
Maarten de Rijke, University of Amsterdam, Amsterdam, Netherlands
Fabrizio Silvestri, Sapienza University of Rome, Rome, Italy

ABSTRACT
Given the increasing promise of Graph Neural Networks (GNNs) in real-world applications, several methods have been developed for explaining their predictions. So far, these methods have primarily focused on generating subgraphs that are especially relevant for a particular prediction. However, such methods do not provide a clear opportunity for recourse: given a prediction, we want to understand how the prediction can be changed in order to achieve a more desirable outcome. In this work, we propose a method for generating counterfactual (CF) explanations for GNNs: the minimal perturbation to the input (graph) data such that the prediction changes. Using only edge deletions, we find that our method can generate CF explanations for the majority of instances across three widely used datasets for GNN explanations, while removing less than 3 edges on average, with at least 94% accuracy. This indicates that our method primarily removes edges that are crucial for the original predictions, resulting in minimal CF explanations.

ACM Reference Format:
Ana Lucic, Maartje ter Hoeve, Gabriele Tolomei, Maarten de Rijke, and Fabrizio Silvestri. 2021. CF-GNNExplainer: Counterfactual Explanations for Graph Neural Networks. In DLG-KDD'21: Deep Learning on Graphs, August 14–18, 2021, Online. ACM, New York, NY, USA, 6 pages. https://doi.org/10.1145/1122445.1122456

1 INTRODUCTION
Advances in machine learning (ML) have led to breakthroughs in several areas of science and engineering, ranging from computer vision, to natural language processing, to conversational assistants. Parallel to the increased performance of ML systems, there is an increasing call for the “understandability” of ML models [7]. Understanding why an ML model returns a certain output in response to a given input is important for a variety of reasons such as model debugging, aiding decision-making, or fulfilling legal requirements [5]. Having certified methods for interpreting ML predictions will help enable their use across a variety of applications [18].

Explainable AI (XAI) refers to the set of techniques “focused on exposing complex AI models to humans in a systematic and interpretable manner” [21]. A large body of work on XAI has emerged in recent years [2, 8]. Counterfactual (CF) explanations are used to explain predictions of individual instances in the form: “If X had been different, Y would not have occurred” [25]. CF explanations are based on CF examples: modified versions of the input sample that result in an alternative output response (i.e., prediction). If the modifications recommended are also clearly actionable, this is referred to as achieving recourse [12, 28].

To motivate our problem, we consider an ML application for computational biology. Drug discovery is a task that involves generating new molecules that can be used for medicinal purposes [26, 33]. Given a candidate molecule, a GNN can predict if this molecule has a certain property that would make it effective in treating a particular disease [9, 19, 32]. If the GNN predicts it does not have this desirable property, CF explanations can help identify the minimal change one should make to this molecule, such that it has this desirable property. This could help us not only generate a new molecule with the desired property, but also understand what structures contribute to a molecule having this property.

Although GNNs have shown state-of-the-art results on tasks involving graph data [3, 38], existing methods for explaining the predictions of GNNs have primarily focused on generating subgraphs that are relevant for a particular prediction [1, 4, 14, 16, 20, 22, 30, 34, 36, 37]. However, none of these methods are able to identify the minimal subgraph automatically, since they all require the user to specify in advance the size of the subgraph, S. We show that even if we adapt existing methods to the CF explanation problem and try varying values for S, such methods are not able to produce valid, accurate explanations, and are therefore not well suited to solve the CF explanation problem. To address this gap, we propose CF-GNNExplainer, which generates CF explanations for GNNs.

Similar to other CF methods proposed in the literature [12, 29], CF-GNNExplainer works by perturbing data at the instance level. In this work, the instances are nodes in the graph, since we focus on node classification. In particular, our method iteratively removes edges from the original adjacency matrix based on matrix sparsification techniques, keeping track of the perturbations that lead to a change in prediction, and returning the perturbation with the smallest change w.r.t. the number of edges. We evaluate CF-GNNExplainer on three public datasets for GNN explanations and measure its effectiveness using four metrics: fidelity, explanation size, sparsity, and accuracy. CF-GNNExplainer is able to generate CF examples with at least 94% accuracy, while removing fewer than three edges on average. In this work, we (i) formalize the problem of generating CF explanations for GNNs, and (ii) propose a novel method for generating CF explanations for GNNs.
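The bookkeeping step described above (track the perturbations that flip the prediction, return the one that removes the fewest edges) can be illustrated with a short sketch. This is not the authors' implementation: the function name, the way candidate perturbation matrices are produced, and the dense-tensor representation are assumptions made purely for illustration.

```python
# Illustrative sketch (not the paper's code): given a trained GNN `model`, a
# target node, and a stream of candidate binary perturbation matrices, keep
# the candidates that change the node's prediction and return the one that
# deletes the fewest edges.
import torch

def best_counterfactual(model, A_v, X_v, v_idx, candidates):
    """A_v: (n, n) subgraph adjacency; X_v: (n, d) node features;
    v_idx: index of the target node; candidates: iterable of (n, n) binary matrices."""
    orig_pred = model(A_v, X_v)[v_idx].argmax().item()
    best_P, best_num_deleted = None, float("inf")
    for P in candidates:
        A_bar = P * A_v                                  # zero entries in P delete edges
        new_pred = model(A_bar, X_v)[v_idx].argmax().item()
        num_deleted = int((A_v - A_bar).sum().item())    # edges removed by this candidate
        if new_pred != orig_pred and num_deleted < best_num_deleted:
            best_P, best_num_deleted = P, num_deleted    # sparsest prediction-flipping perturbation so far
    return best_P, best_num_deleted
```

How the candidate perturbations are generated is the substance of the method and is left abstract here; the sketch only shows the selection of the minimal prediction-changing perturbation.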
2 PROBLEM FORMULATION
Let $f(A, X; W)$ be any GNN, where $A$ is the adjacency matrix, $X$ is the feature matrix, and $W$ is the learned weight matrix of $f$ (i.e., $A$ and $X$ are the inputs of $f$, and $f$ is parameterized by $W$). We define the subgraph neighbourhood of a node $v$ as a tuple of the nodes and edges relevant for the computation of $f(v)$ (i.e., those in the $\ell$-hop neighbourhood of $v$): $G_v = (A_v, X_v)$, where $A_v$ is the subgraph adjacency matrix and $X_v$ is the node feature matrix for nodes that are at most $\ell$ hops away from $v$. We define a node $v$ as a tuple of the form $v = (A_v, x)$, where $x$ is the feature vector for $v$.

In general, a CF example $\bar{x}$ for an instance $x$ according to a trained classifier $f$ is found by perturbing the features of $x$ such that $f(x) \neq f(\bar{x})$ [31]. An optimal CF example $\bar{x}^*$ is one that minimizes the distance between the original instance and the CF example, according to some distance function $d$. The resulting optimal CF explanation is $\Delta^*_x = \bar{x}^* - x$ [15].

For graph data, it may not be enough to simply perturb node features, especially since they are not always available. This is why we are interested in generating CF examples by perturbing the graph structure instead. In other words, we want to change the relationships between instances, rather than change the instances themselves. Therefore, a CF example for graph data has the form $\bar{v} = (\bar{A}_v, x)$, where $x$ is the feature vector and $\bar{A}_v$ is a perturbed version of $A_v$. $\bar{A}_v$ is obtained by removing some edges from $A_v$, such that $f(v) \neq f(\bar{v})$.

Following Wachter et al. [31] and Lucic et al. [15], we generate CF examples by minimizing a loss function of the form:

$$\mathcal{L} = \mathcal{L}_{pred}(v, \bar{v} \mid f, g) + \beta \, \mathcal{L}_{dist}(v, \bar{v}), \qquad (1)$$

where $v$ is the original node, $f$ is the original model, $g$ is the CF model that generates $\bar{v}$, and $\mathcal{L}_{pred}$ is a prediction loss that encourages $f(v) \neq f(\bar{v})$.
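As a concrete illustration of Equation (1), the sketch below instantiates the two terms for the graph setting: a prediction-loss term that is active only while the candidate still receives the original prediction, and a distance term that counts deleted edges. Both instantiations are assumptions made for illustration; the paper defines its own choices of $\mathcal{L}_{pred}$ and $\mathcal{L}_{dist}$.

```python
# Illustrative instantiation of Equation (1); the concrete forms of L_pred
# (negative cross-entropy w.r.t. the original class, active only while the
# prediction is unchanged) and L_dist (number of deleted edges) are
# assumptions, not necessarily the paper's exact definitions.
import torch
import torch.nn.functional as F

def cf_loss(logits_cf, orig_class, A_v, A_bar_v, beta=0.5):
    """logits_cf: 1D tensor of class scores for the candidate CF example;
    orig_class: int, the original prediction f(v);
    A_v, A_bar_v: original and perturbed (n, n) adjacency matrices."""
    still_original = logits_cf.argmax().item() == orig_class
    # L_pred: while the prediction has not changed, push probability mass
    # away from the original class (minimizing the negative cross-entropy).
    l_pred = -F.cross_entropy(logits_cf.unsqueeze(0), torch.tensor([orig_class])) \
        if still_original else torch.tensor(0.0)
    # L_dist: how different the CF example is from the original node, here the
    # number of deleted edges (halved, assuming a symmetric adjacency matrix).
    l_dist = (A_v - A_bar_v).abs().sum() / 2
    return l_pred + beta * l_dist
```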
generating % directly, we first generate a perturbation vector which In general, a CF example G¯ for an instance G according to a we then use to populate %ˆ in a symmetric manner. trained classifier 5 is found by perturbing the features of G such that 5 (G) =6 5 (G¯) [31]. An optimal CF example G¯∗ is one that minimizes 3.2 Counterfactual Generating Model the distance between the original instance and the CF example, We want our perturbation matrix % to only act on 퐴E, not 퐴˜E, in according to some distance function 3. The resulting optimal CF order to preserve self-loops in the message passing of 5 (i.e., we ∗ ∗ explanation is ∆G = G¯ − G [15]. always want a node representation update to include the node’s For graph data, it may not be enough to simply perturb node representation from the previous layer). Therefore, we first rewrite features, especially since they are not always available. This is why Equation 2 for our illustrative one-layer case to isolate 퐴E: we are interested in generating CF examples by perturbing the h −1/2 −1/2 i graph structure instead. In other words, we want to change the 5 (퐴E,-E;, ) = softmax (퐷E + 퐼) (퐴E + 퐼)(퐷E + 퐼) -E, (3) relationships between instances, rather than change the instances To generate CFs, we propose a new function 6, which is based on 5 , themselves. Therefore, a CF example for graph data has the form but it is parameterized by % instead of by , . We update the degree E¯ = (퐴¯ ,G), where G is the feature vector and 퐴¯ is a perturbed E E matrix 퐷E based on % ⊙ 퐴E, add the identity matrix to account for version of 퐴 . 퐴¯ is obtained by removing some edges from 퐴 , E E E self-loops (as in 퐷˜E in Equation 2), and call this 퐷¯E: such that 5 (E) =6 5 (E¯). h −1/2 −1/2 i Following Wachter et al. [31] and Lucic et al. [15], we generate (퐴E,-E,, ; %) = softmax 퐷¯E (% ⊙ 퐴E + 퐼)퐷¯E -E, . (4) CF examples by minimizing a loss function of the form: In other words, 5 learns the weight matrix while holding the data L = L?A43 (E, E¯ j 5 ,6) + VL38BC (E, E¯), (1) constant, while 6 is optimized to find a perturbation matrix that where E is the original node, 5 is the original model, 6 is the CF is then used to generate new data points (i.e., CF examples) while holding the weight matrix (i.e., model) constant. Another distinc- model that generates E¯, and L?A43 is a prediction loss that encour- ages 5 (E) =6 5 (E¯).