Concept-Based Model Explanations for Electronic Health Records
Diana Mincu (Google Research, London, UK), Eric Loreaux (Google Health, Palo Alto, CA, USA), Shaobo Hou (DeepMind, London, UK), Sebastien Baur (Google Health, London, UK), Ivan Protsyuk (Google Health, London, UK), Martin Seneviratne (Google Health, London, UK), Anne Mottram (DeepMind, London, UK), Nenad Tomasev (DeepMind, London, UK), Alan Karthikesalingam (Google Health, London, UK), Jessica Schrouff∗ (Google Research, London, UK; [email protected])

arXiv:2012.02308v2 [cs.LG] 8 Mar 2021

ABSTRACT

Recurrent Neural Networks (RNNs) are often used for sequential modeling of adverse outcomes in electronic health records (EHRs) due to their ability to encode past clinical states. These deep, recurrent architectures have displayed increased performance compared to other modeling approaches in a number of tasks, fueling the interest in deploying deep models in clinical settings. One of the key elements in ensuring safe model deployment and building user trust is model explainability. Testing with Concept Activation Vectors (TCAV) has recently been introduced as a way of providing human-understandable explanations by comparing high-level concepts to the network's gradients. While the technique has shown promising results in real-world imaging applications, it has not been applied to structured temporal inputs. To enable an application of TCAV to sequential predictions in the EHR, we propose an extension of the method to time series data. We evaluate the proposed approach on an open EHR benchmark from the intensive care unit, as well as synthetic data where we are able to better isolate individual effects.

KEYWORDS

Explainability, time series, human-understandable concepts, electronic health records (EHR)

1 INTRODUCTION

Wider availability of Electronic Health Records (EHR) has led to an increase in machine learning applications for clinical diagnosis and prognosis [e.g., 1, 3]. Larger de-identified datasets and public benchmarks have fueled the application of increasingly complex techniques such as recurrent neural networks (RNNs) to predict adverse clinical events [e.g., 7, 18, 27, 33, 35]. RNNs can operate over a sequence of health information, iteratively combining input data with internal memory states to generate new states, making them suitable for continuous clinical predictions. While these memory-storing networks allow for accurate and dynamic predictions, it is often difficult to examine the mechanism by which clinical information is being translated into outputs. In healthcare, as in other fields in which trust is paramount, it is not sufficient to show state-of-the-art discriminative performance; clinicians have also deemed it critical that models provide local and global explanations for their behavior [34].

Multiple approaches have been proposed to provide explanations for machine learning models applied to EHR data [see 21, for a review], with a focus on attention-based methods when the architecture relies on RNNs [e.g., 4, 26, 27]. Typically, those interpretability techniques rank the input features based on their attention scores. However, single-feature rankings might not highlight clinical states that encompass multiple input features (e.g. “infection”) in an intuitive manner. To address this issue of human understandability, Panigutti et al. [20] use an ontology of diagnoses to provide insights across single features. This approach, however, relies on diagnoses, which are typically recorded at the end of an admission, and is therefore not suitable to identify temporal changes across features that reflect a clinical concept, nor is it able to provide continuous predictions.

On the other hand, human-understandable explanations have been successfully developed for computer vision applications: Testing with Concept Activation Vectors [TCAV, 16] relies on human-understandable “concepts” to derive model explanations. Practitioners or end users can select examples from the data that embody intuitive concepts (e.g.
“pointy ears” or “stripes”), and these examples are then used to map concepts to the model's activation space in the form of concept activation vectors (CAVs). CAVs can then be used to provide global explanations, as well as to assess the presence or absence of a concept in local examples.

In this work, we define “clinical concepts” from temporal EHR input features to improve the human-understandability of post-hoc explanations of continuous clinical predictions. Our approach leverages TCAV [16] and can be applied to previously trained models without restrictions on model inputs or RNN architecture. Our contributions are as follows:

• We extend the TCAV approach to the time series setting by defining metrics assessing (1) whether the model encodes the concept, (2) whether the concept is “present” in examples, and (3) whether a concept influences the model's predictions.
• We design a synthetic time series dataset to evaluate (concept-based) attribution methods and demonstrate that the proposed technique is faithful.
• We propose a framework to define human-understandable concepts in EHR and illustrate it using the de-identified MIMIC-III benchmark dataset [13].

∗Corresponding author.
ACM CHIL ’21, April 8–10, 2021, Virtual Event, USA. Mincu et al.

2 METHODS

Notation: We consider a set of multivariate time series $X := \{x_{i,t,d}\}_{i \le N, t \le T_i, d \le D}$, where $x_{i,t,d} \in \mathbb{R}$, $N$ is the number of time series (i.e. patients), $D$ is the number of features per time step, and $T_i$ the number of time steps for patient $i$. We define $\vec{x}_d$ as the time series for feature $d \in \{1, \ldots, D\}$ for a single example. The label $y \in \{0, 1\}^{N \times T}$ exists for all examples and all time steps. We train a recurrent neural network $F : X \to [0, 1]^T$ with $L$ layers. For a given layer $1 \le l \le L$ and time step $1 \le t \le T$, we can write the predicted output of $F$ as $F_t(\vec{x}) := h(f_l(\vec{x}_{1:t}))$, where $f_l(\vec{x}_{1:t})$ is the activation vector at the $l$-th layer after $t$ time steps, further referred to as $\vec{a}_{t,l}$, and $h$ represents the operations in layers $l, \ldots, L$. Please note that we consider binary classification settings, but the approach extends to multi-class predictions.

2.1 Concept-based explanations over time

In this section, we extend TCAV [16] to account for the temporal dimension. TCAV relies on two main steps: (1) building a concept activation vector (CAV) for each concept, and (2) assessing how the concept influences the model's decision.

Building a CAV: To build a CAV, Kim et al. [16] sample positive and negative examples for a concept, record their activations $\vec{a}_l$ at each layer $l$ of the network, and build a linear classifier (e.g. logistic regression) distinguishing between activations related to positive and negative samples. To extend this approach to time series, we identify a ‘time window of interest’ $[t_{start}, t_{end}]$ that reflects a trajectory corresponding to a concept, i.e. during which some features or feature changes are present. We define a ‘control’ group as a set of trajectories in which the concept does not manifest. We then collect the model's activations from $t = t_{start}$ to the end of the window $t_{end}$ for both groups, and training data for CAV learning is defined based on three different strategies:

• $\mathrm{CAV}_{t_{end}}$: we record the model's activations in each layer at $t_{end}$.
• $\mathrm{CAV}_{t_{end}-t_{start}}$: we record the model's activations at $t_{start}$ and at $t_{end}$ and use their difference to train the CAV. In this case, we assume that changes in activations represent the concept of interest.

A concept is considered “encoded” in the model if the linear model performs significantly above chance level. We assess the linear classifier's performance using a bootstrap resampling scheme (k=100, stratified where relevant) and perform random permutations of the labels (1,000 permutations, 10 per bootstrap resample) to obtain a null distribution of balanced accuracy and area under the receiver operating characteristic curve (AUROC). We assess a CAV as significant if all metrics are higher than the estimated null distributions with $p < 0.05$. We then estimate the generalizability of the classifier across time steps by performing the classification at all time points ($t = 1, \ldots, T$), where we give a label of 1 (resp. 0) at time points where the concept is present (resp. absent) when that information is known (i.e. synthetic data), and a label of 1 (resp. 0) for all time points of concept (resp. control) time series when it is not (i.e. clinical application). This measure of performance beyond the $[t_{start}, t_{end}]$ window allows us to understand whether concepts are represented similarly across all time points in the sequence, or whether the signal is specific to the window selected.

Presence of the concept in a sample: The original TCAV work [16] computes the cosine similarity between the activations $\vec{a}$ of a sample and the obtained CAV at each layer to estimate how similar an image is to a concept. This similarity measure can be thought of as estimating whether a concept is manifesting or “present” in the sample. In time series, it can be computed at each time point independently to obtain a (local) trajectory of concept presence per layer:

$$\mathrm{tCA}_C(\vec{x}_t) = \vec{v}_C^{\top} \frac{\vec{a}_t}{\|\vec{a}_t\|_2}$$

where $\vec{v}_C$ corresponds to the unit-norm CAV of concept $C$. This formulation can be extended to estimate whether the activations change over time in the direction of the concept by replacing $\vec{a}_t$ by $[\vec{a}_t - \vec{a}_{t-dt}]$ (following the assumption of local linearity in [16]), where $dt$ represents a constant lag in a time-shifting window. This is relevant for investigating concepts that vary across time, e.g. by becoming more severe, and it is the formulation used throughout this work.

Influence of the concept on the model's prediction: Kim et al. [16] define the Conceptual Sensitivity (CS) to estimate how the model's gradients align with the CAV. This quantity, when aggregated over samples, represents a global explanation. Mathematically, CS can be computed as the directional derivative

$$\mathrm{CS}_{C,l,t}(F, \vec{x}_t) := \nabla h(f_l(\vec{x}_t)) \cdot \vec{v}_C$$