Multiple Instance Learning for ECG Risk Stratification
Proceedings of Machine Learning Research 85:1–15, 2018
Machine Learning for Healthcare

Divya Shanmugam, Massachusetts Institute of Technology, Cambridge, MA, USA
Davis Blalock, Massachusetts Institute of Technology, Cambridge, MA, USA
John Guttag, Massachusetts Institute of Technology, Cambridge, MA, USA

Abstract

Patients who suffer an acute coronary syndrome are at elevated risk for adverse cardiovascular events such as myocardial infarction and cardiovascular death. Accurate assessment of this risk is crucial to their course of care. We focus on estimating a patient's risk of cardiovascular death after an acute coronary syndrome based on a patient's raw electrocardiogram (ECG) signal. Learning from this signal is challenging for two reasons: 1) positive examples signifying a downstream cardiovascular event are scarce, causing drastic class imbalance, and 2) each patient's ECG signal consists of thousands of heartbeats, accompanied by a single label for the downstream outcome. Machine learning has been previously applied to this task, but most approaches rely on hand-crafted features and domain knowledge. We propose a method that learns a representation from the raw ECG signal by using a multiple instance learning framework. We present a learned risk score for cardiovascular death that outperforms existing risk metrics in predicting cardiovascular death within 30, 60, 90, and 365 days on a dataset of 5000 patients.

1. Introduction

Machine learning has led to improved risk stratification models for a number of outcomes, including stroke (Li et al., 2016), cancer (Heidari et al., 2018), in-hospital mortality (Gong et al., 2017), and treatment resistance (Perlis, 2013). We consider risk stratification for patients who experience an acute coronary syndrome (ACS).
An ACS entails an abrupt reduction in blood flow to the heart and refers to three types of coronary artery disease: ST-elevated myocardial infarction, non-ST-elevated myocardial infarction, and unstable angina. Treatment can range from invasive surgery to prescription drugs. Informative short-term risk metrics can help physicians identify where patients lie on the continuum of acute coronary syndromes and make the necessary care decisions to reduce adverse outcomes. Predicting cardiovascular death (CVD) within a certain number of days of hospital admission is a common task in the risk stratification literature (Morrow et al., 2007; Liu et al., 2014; Myers et al., 2017), enabling direct comparison to existing methods.

© 2018 D. Shanmugam, D. Blalock & J. Guttag.

Risk stratification models for cardiovascular outcomes frequently operate on electrocardiogram (ECG) signals to produce risk scores (Acharya et al., 2003; Rahman et al., 2015; Kannathal et al., 2003). Incorporating the ECG signal into a risk model is not straightforward, however, because it is a time series, often measured for days at a time. Many risk metrics compute features based on adjacent heartbeats, including the fluctuation of the ST segment (Klingenheben et al., 2000) and the dynamic time warping distance (Syed et al., 2009). These heuristics represent a good estimate of what a risky ECG signal may look like, but it is not well understood what best indicates higher risk of a cardiovascular event (i.e., cardiovascular death or myocardial infarction). This presents two challenges:

1. Representation. It is not clear what features to extract from each ECG signal. This is made especially difficult by the variable duration of heartbeats.

2. Homogeneity. Different levels of CVD risk at the patient level do not necessarily translate to different characteristics at the level of individual heartbeats.
A high risk patient might have more worrisome heartbeats than a low risk patient, but it is unlikely that they will have exclusively or even a large number of worrisome heartbeats.

Contributions. We address these challenges by fusing ideas from two areas of machine learning: deep learning and Multiple Instance Learning (MIL). We address the representation challenge by learning predictive features directly from the data, with no task-specific engineering, using a compact neural network. This is in contrast to most existing approaches, which extract handcrafted features.

To address the challenge of homogeneity, we cast the task as a MIL problem. MIL assumes that labels for collections of instances are available, but labels for individual instances are not. In our case, the instances are series of consecutive heartbeats, the collections are the instances extracted from the ECG signal for a single patient, and the label is the outcome for the associated patient.

We introduce a task-independent method for learning a signal-based risk metric. Using this method, we also present a new ECG risk score that outperforms existing ECG-based risk metrics in terms of AUC for each of 30, 60, 90, and 365 day risk of cardiovascular death.

Technical Significance. Existing CVD risk metrics rely on hand-selected features. Outside the area of CVD, existing work in machine learning for risk stratification often relies on end-to-end learning (Geng et al., 2014; Farran et al., 2013). In this work, we present a risk stratification framework for ECG signals that borrows from both of these approaches, in which we learn the relationship between consecutive beats and use a simple, fixed function to aggregate this relationship. We break signal-based risk stratification into three steps: instance extraction, classification, and aggregation. The proposed method outperforms both feature-engineered approaches and entirely learned models.
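The three-step decomposition (instance extraction, classification, aggregation) can be illustrated with a minimal sketch. Everything concrete here is an assumption made for illustration and is not taken from the paper: instances are fixed-length, non-overlapping windows rather than segmented heartbeats, the classifier is a hypothetical logistic scorer standing in for the learned neural network, and the fixed aggregation function is a plain mean.

```python
import numpy as np

def extract_instances(signal, window_len):
    """Step 1 (assumed form): cut a 1-D signal into non-overlapping windows.

    Trailing samples that do not fill a whole window are dropped.
    """
    n_windows = len(signal) // window_len
    return signal[: n_windows * window_len].reshape(n_windows, window_len)

def score_instances(instances, weights, bias):
    """Step 2 (hypothetical stand-in): logistic score for each instance."""
    logits = instances @ weights + bias
    return 1.0 / (1.0 + np.exp(-logits))

def aggregate(scores):
    """Step 3 (assumed choice): a fixed aggregation, here the mean score."""
    return float(np.mean(scores))

def patient_risk(signal, window_len, weights, bias):
    """Compose the three steps into a single patient-level risk score."""
    instances = extract_instances(signal, window_len)
    return aggregate(score_instances(instances, weights, bias))
```

Keeping aggregation as a fixed function means only the instance-level scorer has parameters to learn, which is one way to read the split between learned and fixed components described above.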
Clinical Relevance. We aim to reduce the extent to which risk metric development depends on time-intensive feature engineering. While we validate our method only on the task of risk stratification for cardiovascular death, we make no ECG-specific choices. This suggests that our approach could also be applied to other problems.

While it is well established that the inclusion of patient features outside of the ECG increases the predictive power of a given ECG-based risk metric (Syed et al., 2009; Myers et al., 2017), we focus on how to incorporate the ECG signal into a risk stratification model. Using the proposed method, we predict cardiovascular death from a raw ECG signal and present a state-of-the-art risk score with respect to CVD within 30, 60, 90 and 365 days.

We include a review of related work in Section 2 and a summary of our method in Section 3. In Section 4, we outline our experimental setup and follow with results in Section 5. We conclude with a robustness analysis in Section 6 and discussion in Section 7.

2. Related Work

2.1. CVD Risk Stratification

The simplest and most commonly used means of predicting CVD is to construct a model based on easy-to-quantify patient characteristics, such as age, sex, and Left Ventricular Ejection Fraction (LVEF) (Cintron et al., 1993; Antman et al., 2000; Tang et al., 2007; Muntner et al., 2014). There is strong evidence, however, that leveraging ECG signals can add significant predictive power (Liu et al., 2014). Because these signals take the form of long time series, it is not obvious how best to incorporate them into a predictive model. One approach is to treat the entire signal as an input to a model, either by extracting summary statistics (Malik et al., 1996) or feeding it into a recurrent neural network (Myers et al., 2017). An alternative is to treat one ECG signal as a sequence of many examples of heartbeats.
A successful means of doing so is to extract pairs of consecutive heartbeats and represent each pair using a set of informative features (McCraty and Shaffer, 2015; Syed et al., 2009). These features might include polynomial fit coefficients (Sun et al., 2012), Legendre coefficients (Myers et al., 2017) and DTW distance (Liu et al., 2014). Each of these methods relies on heuristics for cardiovascular risk. In contrast, we propose a method that learns a function to estimate risk directly from the ECG signal. We build upon existing work by focusing on consecutive heartbeats and map this problem to the multiple instance learning framework.

2.2. Multiple Instance Learning (MIL)

MIL tasks involve three entities: instances, collections, and labels. In traditional supervised learning problems, each instance is associated with a label. In MIL, instances belong to collections and each collection is associated with a label (Amores, 2013). A wealth of MIL research has occurred since the field's inception in 1998 (Maron and Lozano-Pérez, 1998) and we direct interested readers to the survey by Carbonneau et al. (2018).

There are two basic assumptions in MIL: the standard MI assumption and the collective assumption. Our work depends upon the collective assumption, which states that positive and negative collections differ in the percentage of instances that are positive rather than the existence of a single positive instance. Our approach is also an example of single instance learning (Foulds and Frank, 2010), in which each instance inherits the collection label. We focus on the collection label prediction task corresponding to the downstream patient outcome, as opposed to the instance label prediction task.

3. Method

We assume a collection of $M$ ECG signals $T_1, \ldots, T_M$. Each signal is associated with a unique patient $m$, and consists of $L_m$ scalar samples.
Each patient is associated with a label $y_m \in \{0, 1\}$, where $y_m = 1$ indicates that the patient died of cardiovascular causes within $D$ days of hospital admission. We set $D$ to 30, 60, 90, or 365 days. Our task is to predict the label for held-out patients based on their ECG signals. We treat this as a multiple instance learning problem.
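To make the MIL framing concrete, the sketch below contrasts the two assumptions from Section 2.2 and shows how single instance learning turns patient labels into instance labels. It is an illustrative sketch only: the function names, the 0.5 score threshold, and the `min_fraction` cutoff are hypothetical choices made here, not values from the paper.

```python
import numpy as np

def inherit_labels(collections, collection_labels):
    """Single instance learning: each instance inherits its collection's label.

    collections: list of arrays, one per patient, shaped (n_instances, n_features).
    Returns a flat instance matrix and a matching vector of inherited labels.
    """
    X = np.concatenate(collections, axis=0)
    y = np.concatenate(
        [np.full(len(c), label) for c, label in zip(collections, collection_labels)]
    )
    return X, y

def predict_standard(instance_probs, threshold=0.5):
    """Standard MI assumption: positive if at least one instance is positive."""
    return int(np.max(instance_probs) >= threshold)

def predict_collective(instance_probs, threshold=0.5, min_fraction=0.3):
    """Collective assumption: positive if enough of the instances are positive.

    min_fraction is a hypothetical cutoff chosen only for this example.
    """
    fraction_positive = np.mean(np.asarray(instance_probs) >= threshold)
    return int(fraction_positive >= min_fraction)
```

For a patient whose instance scores are `[0.1, 0.9, 0.2, 0.1]`, the standard assumption yields a positive prediction because one instance crosses the threshold, whereas the collective rule with these settings does not, since only a quarter of the instances are positive.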