{"title": "RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism", "book": "Advances in Neural Information Processing Systems", "page_first": 3504, "page_last": 3512, "abstract": "Accuracy and interpretability are two dominant features of successful predictive models. Typically, a choice must be made in favor of complex black box models such as recurrent neural networks (RNN) for accuracy versus less accurate but more interpretable traditional models such as logistic regression. This tradeoff poses challenges in medicine where both accuracy and interpretability are important. We addressed this challenge by developing the REverse Time AttentIoN model (RETAIN) for application to Electronic Health Records (EHR) data. RETAIN achieves high accuracy while remaining clinically interpretable and is based on a two-level neural attention model that detects influential past visits and significant clinical variables within those visits (e.g. key diagnoses). RETAIN mimics physician practice by attending the EHR data in a reverse time order so that recent clinical visits are likely to receive higher attention. RETAIN was tested on a large health system EHR dataset with 14 million visits completed by 263K patients over an 8 year period and demonstrated predictive accuracy and computational scalability comparable to state-of-the-art methods such as RNN, and ease of interpretability comparable to traditional models.", "full_text": "RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism\n\nEdward Choi*, Mohammad Taha Bahadori*, Joshua A. Kulas*, Andy Schuetz\u2020, Walter F. Stewart\u2020, Jimeng Sun*\n* Georgia Institute of Technology\n\u2020 Sutter Health\n\n{mp2893,bahadori,jkulas3}@gatech.edu, {schueta1,stewarwf}@sutterhealth.org, jsun@cc.gatech.edu\n\nAbstract\n\nAccuracy and interpretability are two dominant features of successful predictive models. 
Typically, a choice must be made in favor of complex black-box models such as recurrent neural networks (RNN) for accuracy versus less accurate but more interpretable traditional models such as logistic regression. This tradeoff poses challenges in medicine, where both accuracy and interpretability are important. We addressed this challenge by developing the REverse Time AttentIoN model (RETAIN) for application to Electronic Health Records (EHR) data. RETAIN achieves high accuracy while remaining clinically interpretable and is based on a two-level neural attention model that detects influential past visits and significant clinical variables within those visits (e.g., key diagnoses). RETAIN mimics physician practice by attending to the EHR data in reverse time order so that recent clinical visits are likely to receive higher attention. RETAIN was tested on a large health system EHR dataset with 14 million visits completed by 263K patients over an 8-year period and demonstrated predictive accuracy and computational scalability comparable to state-of-the-art methods such as RNN, and ease of interpretability comparable to traditional models.\n\n1 Introduction\n\nThe broad adoption of Electronic Health Record (EHR) systems has opened the possibility of applying clinical predictive models to improve the quality of clinical care. Several systematic reviews have underlined the improvement in care quality that predictive analysis can bring [7, 25, 5, 20]. EHR data can be represented as temporal sequences of high-dimensional clinical variables (e.g., diagnoses, medications and procedures), where the sequence ensemble represents the documented content of medical visits from a single patient. 
Traditional machine learning tools summarize this ensemble into aggregate features, ignoring the temporal and sequence relationships among the feature elements. The opportunity to improve both predictive accuracy and interpretability is likely to derive from effectively modeling the temporality and high dimensionality of these event sequences.\n\nAccuracy and interpretability are two dominant features of successful predictive models. There is a common belief that one has to trade accuracy for interpretability using one of three types of traditional models [6]: 1) identifying a set of rules (e.g., via decision trees [27]), 2) case-based reasoning by finding similar patients (e.g., via k-nearest neighbors [18] and distance metric learning [36]), and 3) identifying a list of risk factors (e.g., via LASSO coefficients [15]). While interpretable, all of these models rely on aggregated features, ignoring the temporal relation among features inherent to EHR data. As a consequence, model accuracy is sub-optimal. Latent-variable time-series models, such as [34, 35], account for temporality, but often have limited interpretability due to abstract state variables.\n\nRecently, recurrent neural networks (RNN) have been successfully applied to modeling sequential EHR data to predict diagnoses [30] and disease progression [11, 14]. But the gain in accuracy from the use of RNNs comes at the cost of model output that is notoriously difficult to interpret. While there have been several attempts at directly interpreting RNNs [19, 26, 8], these methods are not sufficiently developed for application in clinical care.\n\n30th Conference on Neural Information Processing Systems (NIPS 2016), Barcelona, Spain.\n\n[Figure 1 diagram: (a) Standard attention model, less interpretable, end-to-end; (b) RETAIN model, interpretable, end-to-end]\n\nFigure 1: Common attention models vs. RETAIN, using folded diagrams of RNNs. (a) Standard attention mechanism: the recurrence on the hidden state vector $v_i$ hinders interpretation of the model. (b) Attention mechanism in RETAIN: the recurrence is on the attention generation components ($h_i$ or $g_i$), while the hidden state $v_i$ is generated by a simpler, more interpretable output.\n\nWe have addressed this limitation using a modeling strategy known as RETAIN, a two-level neural attention model for sequential data that provides detailed interpretation of the prediction results while retaining prediction accuracy comparable to RNN. To this end, RETAIN relies on an attention mechanism modeled to represent the behavior of physicians during an encounter. A distinguishing feature of RETAIN (see Figure 1) is that it leverages sequence information using an attention generation mechanism while learning an interpretable representation. Emulating physician behavior, RETAIN examines a patient\u2019s past visits in reverse time order, facilitating a more stable attention generation. As a result, RETAIN identifies the most meaningful visits and quantifies the visit-specific features that contribute to the prediction.\n\nRETAIN was tested on a large health system EHR dataset with 14 million visits completed by 263K patients over an 8-year period. We compared the predictive accuracy of RETAIN to traditional machine learning methods and to RNN variants using a case-control dataset to predict a future diagnosis of heart failure. The comparative analysis demonstrates that RETAIN achieves comparable performance to RNN in both accuracy and speed and significantly outperforms traditional models. 
Moreover, using a concrete case study and visualization method, we demonstrate how RETAIN offers an intuitive interpretation.\n\n2 Methodology\n\nWe first describe the structure of sequential EHR data and our notation, then follow with a general framework for predictive analysis in healthcare using EHR, followed by details of the RETAIN method.\n\nEHR Structure and our Notation. The EHR data of each patient can be represented as a time-labeled sequence of multivariate observations. Assuming we use $r$ different variables, the $n$-th patient of $N$ total patients can be represented by a sequence of $T^{(n)}$ tuples $(t_i^{(n)}, x_i^{(n)}) \\in \\mathbb{R} \\times \\mathbb{R}^r$, $i = 1, \\ldots, T^{(n)}$. The timestamp $t_i^{(n)}$ denotes the time of the $i$-th visit of the $n$-th patient and $T^{(n)}$ the number of visits of the $n$-th patient. To minimize clutter, we describe the algorithms for a single patient and drop the superscript $(n)$ whenever it is unambiguous. The goal of predictive modeling is to predict the label at each time step, $y_i \\in \\{0, 1\\}^s$, or at the end of the sequence, $y \\in \\{0, 1\\}^s$. The number of labels $s$ can be more than one.\n\nFor example, in disease progression modeling (DPM) [11], each visit of a patient visit sequence is represented by a set of a varying number of medical codes $\\{c_1, c_2, \\ldots, c_n\\}$, where $c_j$ is the $j$-th code from the vocabulary $\\mathcal{C}$. Therefore, in DPM, the number of variables is $r = |\\mathcal{C}|$ and the input $x_i \\in \\{0, 1\\}^{|\\mathcal{C}|}$ is a binary vector where the value one in the $j$-th coordinate indicates that $c_j$ was documented in the $i$-th visit. Given a sequence of visits $x_1, \\ldots, x_T$, the goal of DPM is, for each time step $i$, to predict the codes occurring at the next visit $x_2, \\ldots, x_{T+1}$, with the number of labels $s = |\\mathcal{C}|$.\n\nIn the case of learning to diagnose (L2D) [30], the input vector $x_i$ consists of continuous clinical measures. If there are $r$ different measurements, then $x_i \\in \\mathbb{R}^r$. The goal of L2D is, given an input sequence $x_1, \\ldots, x_T$, to predict the occurrence of a specific disease ($s = 1$) or multiple diseases ($s > 1$). Without loss of generality, we will describe the algorithm for DPM, as L2D can be seen as a special case of DPM where we make a single prediction at the end of the visit sequence.\n\nIn the rest of this section, we will use the abstract symbol RNN to denote any recurrent neural network variant that can cope with the vanishing gradient problem [3], such as LSTM [23], GRU [9], and IRNN [29], with any depth (number of hidden layers).\n\n2.1 Preliminaries on Neural Attention Models\n\nAttention-based neural network models are being successfully applied to image processing [1, 32, 21, 37], natural language processing [2, 22, 33] and speech recognition [12]. The utility of the attention mechanism can be seen in the language translation task [2], where it is inefficient to represent an entire sentence with one fixed-size vector because neural translation machines find it difficult to translate a sentence represented by a single vector.\n\nIntuitively, the attention mechanism for language translation works as follows: given a sentence of length $S$ in the original language, we generate $h_1, \\ldots, h_S$ to represent the words in the sentence. To find the $j$-th word in the target language, we generate attentions $\\alpha_i^j$ for $i = 1, \\ldots, S$, one for each word in the original sentence. Then, we compute the context vector $c_j = \\sum_i \\alpha_i^j h_i$ and use it to predict the $j$-th word in the target language. In general, the attention mechanism allows the model to focus on a specific word (or words) in the given sentence when generating each word in the target language.\n\nWe rely on a conceptually similar temporal attention mechanism to generate interpretable prediction models using EHR data. Our model framework is motivated by and mimics how doctors attend to a patient\u2019s needs and explore the patient record, focusing on specific clinical information (e.g., key risk factors) while working from the present to the past.\n\n2.2 Reverse Time Attention Model RETAIN\n\nFigure 2 shows the high-level overview of our model, where a central feature is to delegate a considerable portion of the prediction responsibility to the process for generating attention weights. This is intended to address, in part, the difficulty with interpreting RNNs, where the recurrent weights feed past information to the hidden layer. Therefore, to consider both the visit-level and the variable-level (individual coordinates of $x_i$) influence, we use a linear embedding of the input vector $x_i$. That is, we define\n\n$$v_i = W_{\\mathrm{emb}} x_i, \\tag{Step 1}$$\n\nwhere $v_i \\in \\mathbb{R}^m$ denotes the embedding of the input vector $x_i \\in \\mathbb{R}^r$, $m$ the size of the embedding dimension, and $W_{\\mathrm{emb}} \\in \\mathbb{R}^{m \\times r}$ the embedding matrix to learn. We can alternatively use more sophisticated yet interpretable representations such as those derived from multilayer perceptron (MLP) [13, 28]. MLP has been used for representation learning in EHR data [10].\n\nWe use two sets of weights, one for the visit-level attention and the other for the variable-level attention. The scalars $\\alpha_1, \\ldots, \\alpha_i$ are the visit-level attention weights that govern the influence of each visit embedding $v_1, \\ldots, v_i$. The vectors $\\beta_1, \\ldots, \\beta_i$ are the variable-level attention weights that focus on each coordinate of the visit embedding $v_{1,1}, v_{1,2}, \\ldots, v_{1,m}, \\ldots, v_{i,1}, v_{i,2}, \\ldots, v_{i,m}$.\n\nWe use two RNNs, $\\mathrm{RNN}_\\alpha$ and $\\mathrm{RNN}_\\beta$, to separately generate the $\\alpha$'s and $\\beta$'s as follows,\n\n$$g_i, g_{i-1}, \\ldots, g_1 = \\mathrm{RNN}_\\alpha(v_i, v_{i-1}, \\ldots, v_1),$$\n$$e_j = w_\\alpha^\\top g_j + b_\\alpha, \\quad j = 1, \\ldots, i,$$\n$$\\alpha_1, \\alpha_2, \\ldots, \\alpha_i = \\mathrm{Softmax}(e_1, e_2, \\ldots, e_i), \\tag{Step 2}$$\n\n$$h_i, h_{i-1}, \\ldots, h_1 = \\mathrm{RNN}_\\beta(v_i, v_{i-1}, \\ldots, v_1),$$\n$$\\beta_j = \\tanh(W_\\beta h_j + b_\\beta), \\quad j = 1, \\ldots, i, \\tag{Step 3}$$\n\nwhere $g_i \\in \\mathbb{R}^p$ is the hidden layer of $\\mathrm{RNN}_\\alpha$ at time step $i$, $h_i \\in \\mathbb{R}^q$ the hidden layer of $\\mathrm{RNN}_\\beta$ at time step $i$, and $w_\\alpha \\in \\mathbb{R}^p$, $b_\\alpha \\in \\mathbb{R}$, $W_\\beta \\in \\mathbb{R}^{m \\times q}$ and $b_\\beta \\in \\mathbb{R}^m$ are the parameters to learn. The hyperparameters $p$ and $q$ determine the hidden layer sizes of $\\mathrm{RNN}_\\alpha$ and $\\mathrm{RNN}_\\beta$, respectively. Note that for prediction at each time step, we generate a new set of attention vectors $\\alpha$ and $\\beta$. For simplicity of notation, we do not include the index for predicting at different time steps. In Step 2, we can use Sparsemax [31] instead of Softmax for sparser attention weights.\n\n[Figure 2 diagram: unfolded view of RETAIN over time, with Steps 1\u20135 marked]\n\nFigure 2: Unfolded view of RETAIN\u2019s architecture: given the input sequence $x_1, \\ldots, x_i$, we predict the label $y_i$. Step 1: Embedding; Step 2: generating $\\alpha$ values using $\\mathrm{RNN}_\\alpha$; Step 3: generating $\\beta$ values using $\\mathrm{RNN}_\\beta$; Step 4: generating the context vector using attention and representation vectors; and Step 5: making the prediction. Note that in Steps 2 and 3 we run the RNNs in reversed time.\n\nAs noted, RETAIN generates the attention vectors by running the RNNs backward in time; i.e., $\\mathrm{RNN}_\\alpha$ and $\\mathrm{RNN}_\\beta$ both take the visit embeddings in reverse order $v_i, v_{i-1}, \\ldots, v_1$. Running the RNNs in reverse time order also offers computational advantages, since it allows us to generate $e$'s and $\\beta$'s that dynamically change their values when making predictions at different time steps $i = 1, 2, \\ldots, T$. 
This ensures that the attention vectors are modified at each time step, increasing the computational stability of the attention generation process.\u00b9\n\nUsing the generated attentions, we obtain the context vector $c_i$ for a patient up to the $i$-th visit as follows,\n\n$$c_i = \\sum_{j=1}^{i} \\alpha_j \\beta_j \\odot v_j, \\tag{Step 4}$$\n\nwhere $\\odot$ denotes element-wise multiplication. We use the context vector $c_i \\in \\mathbb{R}^m$ to predict the true label $y_i \\in \\{0, 1\\}^s$ as follows,\n\n$$\\hat{y}_i = \\mathrm{Softmax}(W c_i + b), \\tag{Step 5}$$\n\nwhere $W \\in \\mathbb{R}^{s \\times m}$ and $b \\in \\mathbb{R}^s$ are parameters to learn. We use the cross-entropy to calculate the classification loss as follows,\n\n$$\\mathcal{L}(x_1, \\ldots, x_T) = -\\frac{1}{N} \\sum_{n=1}^{N} \\frac{1}{T^{(n)}} \\sum_{i=1}^{T^{(n)}} \\Big( y_i^\\top \\log(\\hat{y}_i) + (\\mathbf{1} - y_i)^\\top \\log(\\mathbf{1} - \\hat{y}_i) \\Big), \\tag{1}$$\n\nwhere we sum the cross-entropy errors from all dimensions of $\\hat{y}_i$. In the case of real-valued output $y_i \\in \\mathbb{R}^s$, we can change the cross-entropy in Eq. (1) to, for example, mean squared error.\n\nOverall, our attention mechanism can be viewed as the inverted architecture of the standard attention mechanism for NLP [2], where the words are encoded by an RNN and the attention weights are generated by an MLP. In contrast, our method uses an MLP to embed the visit information to preserve interpretability and uses RNNs to generate two sets of attention weights, recovering the sequential information as well as mimicking the behavior of physicians. Note that we did not use the timestamp of each visit in our formulation. Using timestamps, however, provides a small improvement in prediction performance. We propose a method to use timestamps in Appendix A.\n\n\u00b9 For example, feeding visit embeddings in the original order to $\\mathrm{RNN}_\\alpha$ and $\\mathrm{RNN}_\\beta$ will generate the same $e_1$ and $\\beta_1$ for every time step $i = 1, 2, \\ldots, T$. Moreover, in many cases, a patient\u2019s recent visit records deserve more attention than the old records. 
Then we need to have $e_{j+1} > e_j$, which makes the process computationally unstable for long sequences.\n\n3 Interpreting RETAIN\n\nFinding the visits that contribute to the prediction is straightforward: they are the visits with the largest $\\alpha_j$. However, finding influential variables is slightly more involved, as a visit is represented by an ensemble of medical variables, each of which can vary in its predictive contribution. The contribution of each variable is determined by $v$, $\\beta$ and $\\alpha$; interpreting $\\alpha$ alone tells us which visit is influential in the prediction but not why.\n\nWe propose a method to interpret the end-to-end behavior of RETAIN. By keeping the $\\alpha$ and $\\beta$ values fixed as the attention of doctors, we analyze changes in the probability of each label $y_{i,1}, \\ldots, y_{i,s}$ in relation to changes in the original input $x_{1,1}, \\ldots, x_{1,r}, \\ldots, x_{i,1}, \\ldots, x_{i,r}$. The $x_{j,k}$ that yields the largest change in $y_{i,d}$ will be the input variable with the highest contribution. More formally, given the sequence $x_1, \\ldots, x_i$, we are trying to predict the probability of the output vector $y_i \\in \\{0, 1\\}^s$, which can be expressed as follows,\n\n$$p(y_i \\mid x_1, \\ldots, x_i) = p(y_i \\mid c_i) = \\mathrm{Softmax}(W c_i + b), \\tag{2}$$\n\nwhere $c_i \\in \\mathbb{R}^m$ denotes the context vector. According to Step 4, $c_i$ is the sum of the visit embeddings $v_1, \\ldots, v_i$ weighted by the attentions $\\alpha$'s and $\\beta$'s. Therefore Eq (2) can be rewritten as follows,\n\n$$p(y_i \\mid x_1, \\ldots, x_i) = p(y_i \\mid c_i) = \\mathrm{Softmax}\\Big( W \\Big( \\sum_{j=1}^{i} \\alpha_j \\beta_j \\odot v_j \\Big) + b \\Big). \\tag{3}$$\n\nUsing the fact that the visit embedding $v_i$ is the sum of the columns of $W_{\\mathrm{emb}}$ weighted by each element of $x_i$, Eq (3) can be rewritten as follows,\n\n$$p(y_i \\mid x_1, \\ldots, x_i) = \\mathrm{Softmax}\\Big( W \\Big( \\sum_{j=1}^{i} \\alpha_j \\beta_j \\odot \\sum_{k=1}^{r} x_{j,k} W_{\\mathrm{emb}}[:, k] \\Big) + b \\Big) = \\mathrm{Softmax}\\Big( \\sum_{j=1}^{i} \\sum_{k=1}^{r} x_{j,k} \\, \\alpha_j W \\big( \\beta_j \\odot W_{\\mathrm{emb}}[:, k] \\big) + b \\Big), \\tag{4}$$\n\nwhere $x_{j,k}$ is the $k$-th element of the input vector $x_j$. 
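To make Steps 1 to 5 and the rewrite in Eq (4) concrete, the following is a minimal pure-Python sketch of one RETAIN prediction for a single patient. It illustrates the equations under stated assumptions and is not the authors' Theano code. All weights are small untrained random placeholders, and a plain tanh recurrence (the hypothetical helper simple_rnn) stands in for the GRU. The final check confirms numerically that the pre-Softmax scores W c_i + b equal b plus the sum, over visits j and variables k, of x_{j,k} α_j W(β_j ⊙ W_emb[:, k]).

```python
import math
import random

random.seed(0)
r, m, p, q, s = 4, 3, 3, 3, 2  # input variables, embedding size, RNN_alpha size, RNN_beta size, labels

def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def vadd(u, w):
    return [a + b for a, b in zip(u, w)]

def simple_rnn(inputs, W_in, W_rec):
    # Hypothetical stand-in for the GRU: h_t = tanh(W_in x_t + W_rec h_(t-1))
    h = [0.0] * len(W_rec)
    states = []
    for x in inputs:
        h = [math.tanh(z) for z in vadd(matvec(W_in, x), matvec(W_rec, h))]
        states.append(h)
    return states

def softmax(z):
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [v / total for v in exps]

# Untrained placeholder parameters with the shapes given in Section 2.2
W_emb = rand_matrix(m, r)
Wa_in, Wa_rec = rand_matrix(p, m), rand_matrix(p, p)
w_alpha, b_alpha = [random.uniform(-0.5, 0.5) for _ in range(p)], 0.1
Wb_in, Wb_rec = rand_matrix(q, m), rand_matrix(q, q)
W_beta, b_beta = rand_matrix(m, q), [0.0] * m
W, b = rand_matrix(s, m), [0.05, -0.05]

def retain_logits(xs):
    v = [matvec(W_emb, x) for x in xs]                   # Step 1
    rev = v[::-1]                                         # feed v_i, ..., v_1
    g = simple_rnn(rev, Wa_in, Wa_rec)[::-1]              # g[j] now pairs with v[j]
    h = simple_rnn(rev, Wb_in, Wb_rec)[::-1]
    e = [sum(wa * gv for wa, gv in zip(w_alpha, gj)) + b_alpha for gj in g]
    alpha = softmax(e)                                    # Step 2
    beta = [[math.tanh(z) for z in vadd(matvec(W_beta, hj), b_beta)] for hj in h]  # Step 3
    c = [0.0] * m
    for a_j, b_j, v_j in zip(alpha, beta, v):             # Step 4
        c = vadd(c, [a_j * bb * vv for bb, vv in zip(b_j, v_j)])
    return vadd(matvec(W, c), b), alpha, beta            # W c_i + b; Step 5 applies Softmax

def contributions(xs, alpha, beta):
    # omega[j][k] = x_{j,k} * alpha_j * W (beta_j times W_emb[:, k] elementwise), one value per label
    omega = []
    for x_j, a_j, b_j in zip(xs, alpha, beta):
        row = []
        for k in range(r):
            col = [b_j[d] * W_emb[d][k] for d in range(m)]
            row.append([x_j[k] * a_j * z for z in matvec(W, col)])
        omega.append(row)
    return omega

visits = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 1, 0, 1]]  # toy binary visit vectors x_1, x_2, x_3
logits, alpha, beta = retain_logits(visits)
omega = contributions(visits, alpha, beta)
recon = list(b)
for row in omega:
    for w_jk in row:
        recon = vadd(recon, w_jk)
print('visit-level attention alpha:', [round(a, 3) for a in alpha])
print('max |logit - sum of contributions|:', max(abs(u - w) for u, w in zip(logits, recon)))
print('y_hat (Step 5):', softmax(logits))
```

Because α and β are held fixed, each entry of omega is an exact additive piece of the pre-Softmax score, not an approximation. The second printed line therefore shows only floating-point rounding error.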
Eq (4) can be completely deconstructed into the variables at each input $x_1, \\ldots, x_i$, which allows us to calculate the contribution $\\omega$ of the $k$-th variable of the input $x_j$ at time step $j \\le i$ for predicting $y_i$ as follows,\n\n$$\\omega(y_i, x_{j,k}) = \\underbrace{\\alpha_j W \\big( \\beta_j \\odot W_{\\mathrm{emb}}[:, k] \\big)}_{\\text{Contribution coefficient}} \\underbrace{x_{j,k}}_{\\text{Input value}}, \\tag{5}$$\n\nwhere the index $i$ of $y_i$ is omitted in $\\alpha_j$ and $\\beta_j$. As we have described in Section 2.2, we are generating the $\\alpha$'s and $\\beta$'s at time step $i$ in the visit sequence $x_1, \\ldots, x_T$. Therefore the index $i$ is always assumed for the $\\alpha$'s and $\\beta$'s. Additionally, Eq (5) shows that when we are using a binary input value, the coefficient itself is the contribution. However, when we are using a non-binary input value, we need to multiply the coefficient and the input value $x_{j,k}$ to correctly calculate the contribution.\n\n4 Experiments\n\nWe compared the performance of RETAIN to RNNs and traditional machine learning methods. Given space constraints, we only report the results on the learning to diagnose (L2D) task and summarize the disease progression modeling (DPM) results in Appendix C. The RETAIN source code is publicly available at https://github.com/mp2893/retain.\n\n4.1 Experimental setting\n\nSource of data: The dataset consists of electronic health records from Sutter Health. The patients are adults aged 50 to 80, chosen for a heart failure prediction model study. From the encounter records, medication orders, procedure orders and problem lists, we extracted visit records consisting of diagnosis, medication and procedure codes. To reduce the dimensionality while preserving the clinical information, we used existing medical groupers to aggregate the codes into input variables. The details of the medical groupers are given in Appendix B. A profile of the dataset is summarized in Table 1.\n\nTable 1: Statistics of EHR dataset (D: Diagnosis, R: Medication, P: Procedure)\n\n# of patients: 263,683\n# of visits: 14,366,030\nAvg. # of visits per patient: 54.48\n# of medical code groups: 615 (D:283, R:94, P:238)\nAvg. # of codes in a visit: 3.03\nMax # of codes in a visit: 62\nAvg. # of Dx codes in a visit: 1.83\nMax # of Dx codes in a visit: 42\n\nImplementation details: We implemented RETAIN with Theano 0.8 [4]. For training the model, we used Adadelta [38] with mini-batches of 100 patients. The training was done on a machine equipped with an Intel Xeon E5-2630, 256GB RAM, two Nvidia Tesla K80s and CUDA 7.5.\n\nBaselines: For comparison, we implemented the following models.\n\u2022 Logistic regression (LR): We compute the counts of medical codes for each patient based on all her visits as input variables and normalize the vector to zero mean and unit variance. We use the resulting vector to train the logistic regression.\n\u2022 MLP: We use the same feature construction as LR, but put a hidden layer of size 256 between the input and output.\n\u2022 RNN: RNN with two hidden layers of size 256 implemented by the GRU. Input sequences $x_1, \\ldots, x_i$ are used. Logistic regression is applied to the top hidden layer. We use two layers of RNN to match the model complexity of RETAIN.\n\u2022 RNN+$\\alpha_M$: One-layer single-directional RNN (hidden layer size 256) along time to generate the input embeddings $v_1, \\ldots, v_i$. We use the MLP with a single hidden layer of size 256 to generate the visit-level attentions $\\alpha_1, \\ldots, \\alpha_i$. We use the input embeddings $v_1, \\ldots, v_i$ as the input to the MLP. This baseline corresponds to Figure 1a.\n\u2022 RNN+$\\alpha_R$: This is similar to RNN+$\\alpha_M$ but uses the reverse-order RNN (hidden layer size 256) to generate the visit-level attentions $\\alpha_1, \\ldots, \\alpha_i$. 
We use this baseline to confirm the effectiveness of generating the attentions using reverse time order.\n\nA comparative visualization of the baselines is provided in Appendix D. We use the same implementation and training method for the baselines as described above. The details of the hyper-parameters, regularization and dropout strategies for the baselines are described in Appendix B.\n\nEvaluation measures: Model accuracy was measured by:\n\u2022 Negative log-likelihood, which measures the model loss on the test set. The loss can be calculated by Eq (1).\n\u2022 Area Under the ROC Curve (AUC), comparing $\\hat{y}_i$ with the true label $y_i$. AUC is more robust to imbalanced positive/negative prediction labels, making it appropriate for evaluating classification accuracy in the heart failure prediction task.\n\nWe also report the bootstrap (10,000 runs) estimate of the standard deviation of the evaluation measures.\n\n4.2 Heart Failure Prediction\n\nObjective: Given a visit sequence $x_1, \\ldots, x_T$, we predicted whether a primary care patient will be diagnosed with heart failure (HF). This is a special case of DPM with a single disease outcome at the end of the sequence. Since this is a binary prediction task, we use the logistic sigmoid function instead of the Softmax in Step 5.\n\nCohort construction: From the source dataset, 3,884 cases are selected and approximately 10 controls are selected for each case (28,903 controls). The case/control selection criteria are fully described in the supplementary section. Cases have index dates to denote the date they are diagnosed with HF. Controls have the same index dates as their corresponding cases. We extract diagnosis codes, medication codes and procedure codes in the 18-month window before the index date.\n\nTraining details: The patient cohort was divided into the training, validation and test sets in a 0.75:0.1:0.15 ratio. 
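The two evaluation measures described above, AUC and a bootstrap estimate of its standard deviation, can be sketched in pure Python. This is an assumed illustration and not the authors' evaluation script. AUC is computed in its rank-based (Mann-Whitney) form, and the labels and scores are made-up toy values. Resamples that contain a single class are skipped, which is a choice the paper does not specify. For speed, n_boot defaults to 1,000 here rather than the 10,000 runs reported above.

```python
import random
import statistics

def auc(labels, scores):
    # Rank-based AUC (Mann-Whitney form): fraction of positive/negative pairs
    # in which the positive gets the higher score; ties count as one half
    pos = [sc for y, sc in zip(labels, scores) if y == 1]
    neg = [sc for y, sc in zip(labels, scores) if y == 0]
    wins = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                wins += 1.0
            elif sp == sn:
                wins += 0.5
    return wins / (len(pos) * len(neg))

def bootstrap_std(labels, scores, metric, n_boot=1000, seed=0):
    # Resample test patients with replacement and return the std of the metric
    rng = random.Random(seed)
    n = len(labels)
    values = []
    while len(values) < n_boot:
        idx = [rng.randrange(n) for _ in range(n)]
        ys = [labels[i] for i in idx]
        if 0 < sum(ys) < n:  # skip resamples that contain only one class
            values.append(metric(ys, [scores[i] for i in idx]))
    return statistics.stdev(values)

# Toy labels (1 = HF case, 0 = control) and predicted risks, for illustration only
labels = [1, 0, 0, 1, 0, 0, 0, 1, 0, 0]
scores = [0.9, 0.2, 0.4, 0.7, 0.1, 0.35, 0.5, 0.3, 0.05, 0.6]
print(round(auc(labels, scores), 4))  # 0.8095 (17 of 21 pairs ranked correctly)
print(round(bootstrap_std(labels, scores, auc), 4))
```

The same bootstrap_std helper works for the negative log-likelihood if a metric with the same signature is passed in.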
The validation set was used to determine the values of the hyper-parameters. See Appendix B for details of hyper-parameter tuning.\n\nTable 2: Heart failure prediction performance of RETAIN and the baselines\n\nModel | Test Neg Log Likelihood | AUC | Train Time / epoch | Test Time\nLR | 0.3269 \u00b1 0.0105 | 0.7900 \u00b1 0.0111 | 0.15s | 0.11s\nMLP | 0.2959 \u00b1 0.0083 | 0.8256 \u00b1 0.0096 | 0.25s | 0.11s\nRNN | 0.2577 \u00b1 0.0082 | 0.8706 \u00b1 0.0080 | 10.3s | 0.57s\nRNN+$\\alpha_M$ | 0.2691 \u00b1 0.0082 | 0.8624 \u00b1 0.0079 | 6.7s | 0.48s\nRNN+$\\alpha_R$ | 0.2605 \u00b1 0.0088 | 0.8717 \u00b1 0.0080 | 10.4s | 0.62s\nRETAIN | 0.2562 \u00b1 0.0083 | 0.8705 \u00b1 0.0081 | 10.8s | 0.63s\n\nResults: Logistic regression and MLP underperformed compared to the four temporal learning algorithms (Table 2). RETAIN is comparable to the other RNN variants in terms of prediction performance while offering the interpretation benefit.\n\nNote that the RNN+$\\alpha_R$ model is a degenerate version of RETAIN with only scalar attention, yet it remains competitive, as shown in Table 2. This confirms the effectiveness of generating attention weights using the RNN. However, the RNN+$\\alpha_R$ model only provides scalar visit-level attention, which is not sufficient for healthcare applications. Patients often receive several medical codes at a single visit, and it is important to distinguish their relative importance to the target. We show such a case study in Section 4.3.\n\nTable 2 also shows the scalability of RETAIN, as its training time (the number of seconds to train the model over the entire training set once) is comparable to RNN. The test time is the number of seconds to generate the prediction output for the entire test set. We use mini-batches of 100 patients when assessing both training and test times. RNN takes longer than RNN+$\\alpha_M$ because of its two-layer structure, whereas RNN+$\\alpha_M$ uses a single-layer RNN. 
The models that use two RNNs (RNN, RNN+$\\alpha_R$, RETAIN)\u00b2 take similar time to train for one epoch. However, each model required a different number of epochs to converge: RNN typically takes approximately 10 epochs, RNN+$\\alpha_M$ and RNN+$\\alpha_R$ 15 epochs, and RETAIN 30 epochs. Lastly, training the attention models (RNN+$\\alpha_M$, RNN+$\\alpha_R$ and RETAIN) for DPM would take considerably longer than for L2D, because DPM modeling generates context vectors at each time step. RNN, on the other hand, does not require any additional computation beyond embedding the visit into its hidden layer to predict target labels at each time step. Therefore, in DPM, the training time of the attention models will increase linearly in relation to the length of the input sequence.\n\n4.3 Model Interpretation for Heart Failure Prediction\n\nWe evaluated the interpretability of RETAIN in the HF prediction task by choosing an HF patient from the test set and calculating the contribution of the variables (medical codes in this case) to the diagnostic prediction. The patient suffered from skin problems, namely skin disorder (SD), benign neoplasm (BN) and excision of skin lesion (ESL), for some time before showing symptoms of HF, namely cardiac dysrhythmia (CD), heart valve disease (HVD) and coronary atherosclerosis (CA), and then received a diagnosis of HF (Figure 3). We can see that skin-related codes from the earlier visits made little contribution to the HF prediction, as expected. RETAIN properly puts more attention on the HF-related codes that occurred in recent visits.\n\nTo confirm RETAIN\u2019s ability to exploit the sequence information of the EHR data, we reverse the visit sequence of Figure 3a and feed it to RETAIN. Figure 3b shows the contributions of the medical codes of the reversed visit record. HF-related codes in the past still make positive contributions, but not as much as they did in Figure 3a. 
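Reversing the visit order, as done for Figure 3b, changes what RETAIN reads but not what a count-based model reads. The sketch below contrasts the two using a toy record built from the Figure 3 code abbreviations. The record itself is hypothetical. The function count_features mimics the LR/MLP feature construction of Section 4.1 but omits the zero-mean, unit-variance scaling, which would be applied identically to both vectors. The function reverse_time_inputs is a hypothetical helper that shows the visit order RNN_α and RNN_β receive when predicting at the last visit.

```python
from collections import Counter

def count_features(visits, vocab):
    # LR/MLP-style input: code counts aggregated over all visits, so visit order is discarded
    counts = Counter(code for visit in visits for code in visit)
    return [counts[c] for c in vocab]

def reverse_time_inputs(visits):
    # Order in which RETAIN's attention RNNs read the visits: most recent first
    return list(reversed(visits))

vocab = ['SD', 'ESL', 'BN', 'CD', 'HVD', 'CA']
# Hypothetical toy record: skin-related codes early, cardiac codes late (cf. Figure 3a)
record = [['SD'], ['SD', 'ESL'], ['SD', 'BN'], ['CD'], ['CD', 'HVD'], ['CD', 'CA']]
reversed_record = list(reversed(record))  # cf. Figure 3b

print(count_features(record, vocab) == count_features(reversed_record, vocab))  # True
print(reverse_time_inputs(record)[0], reverse_time_inputs(reversed_record)[0])  # ['CD', 'CA'] ['SD']
```

Both orderings give the identical count vector [3, 1, 1, 3, 1, 1]. In contrast, the first visit the attention RNNs see switches from the cardiac codes to a skin disorder, so RETAIN can assign the two records different risks.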
Figure 3b also emphasizes RETAIN\u2019s superiority to interpretable but stationary models such as logistic regression. Stationary models often aggregate past information and remove the temporality from the input data, which can mistakenly lead to the same risk prediction for Figures 3a and 3b. RETAIN, however, correctly digests the sequence information and calculates an HF risk score of 9.0%, which is significantly lower than that of Figure 3a.\n\nFigure 3c shows how the contributions of codes change when selected medication data are used in the model. We added two medications from day 219: antiarrhythmics (AA) and anticoagulants (AC), both of which are used to treat cardiac dysrhythmia (CD). The two medications make negative contributions, especially towards the end of the record. The medications decreased the positive contributions of heart valve disease and cardiac dysrhythmia in the last visit. Indeed, the HF risk prediction (0.2165) of Figure 3c is lower than that of Figure 3a (0.2474). This suggests that taking proper medications can help the patient reduce their HF risk.\n\n\u00b2 The RNN baseline uses two layers of RNN, RNN+$\\alpha_R$ uses one for visit embedding and one for generating $\\alpha$, and RETAIN uses one each for generating $\\alpha$ and $\\beta$.\n\n[Figure 3 plot: per-visit code contributions (y-axis: Contribution; x-axis: Time in days) for three versions of the same record: (a) HF risk: 0.2474; (b) HF risk: 0.0905; (c) HF risk: 0.2165. Codes: SD: Skin disorder; ESL: Excision of skin lesion; BN: Benign neoplasm; CD: Cardiac dysrhythmia; CA: Coronary atherosclerosis; HVD: Heart valve disorder; AA: Antiarrhythmic medication; AC: Anticoagulant medication.]\n\nFigure 3: (a) Temporal visualization of a patient\u2019s visit records, where the contribution of variables for diagnosis of heart failure (HF) is summarized along the x-axis (i.e., time), with the y-axis indicating the magnitude of visit- and code-specific contributions to HF diagnosis. (b) We reverse the order of the visit sequence to see if RETAIN can properly take into account the modified sequence information. (c) Medication codes are added to the visit record to see how they change the behavior of RETAIN.\n\n5 Conclusion\n\nOur approach to modeling event sequences as predictors of HF diagnosis suggests that complex models can offer both superior predictive accuracy and more precise interpretability. Given the power of RNNs for analyzing sequential data, we proposed RETAIN, which preserves RNN\u2019s predictive power while allowing a higher degree of interpretation. The key idea of RETAIN is to improve the prediction accuracy through a sophisticated attention generation process, while keeping the representation learning part simple for interpretation, making the entire algorithm accurate and interpretable. RETAIN trains two RNNs in reverse time order to efficiently generate the appropriate attention variables. For future work, we plan to develop an interactive visualization system for RETAIN and to evaluate RETAIN in other healthcare applications.\n\nReferences\n[1] J. Ba, V. Mnih, and K. Kavukcuoglu. Multiple object recognition with visual attention. In ICLR, 2015.\n[2] D. Bahdanau, K. Cho, and Y. Bengio. Neural machine translation by jointly learning to align and translate. In ICLR, 2015.\n[3] Y. Bengio, P. Simard, and P. Frasconi. 
Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks, 5(2):157\u2013166, 1994.\n\n[4] J. Bergstra, O. Breuleux, F. Bastien, P. Lamblin, R. Pascanu, G. Desjardins, J. Turian, D. Warde-Farley, and Y. Bengio. Theano: a CPU and GPU math expression compiler. In Proceedings of SciPy, 2010.\n\n[5] A. D. Black, J. Car, C. Pagliari, C. Anandan, K. Cresswell, T. Bokun, B. McKinstry, R. Procter, A. Majeed, and A. Sheikh. The impact of eHealth on the quality and safety of health care: a systematic overview. PLoS Medicine, 8(1):e1000387, 2011.\n\n[6] R. Caruana, Y. Lou, J. Gehrke, P. Koch, M. Sturm, and N. Elhadad. Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission. In KDD, 2015.\n\n[7] B. Chaudhry, J. Wang, S. Wu, M. Maglione, W. Mojica, E. Roth, S. C. Morton, and P. G. Shekelle. Systematic review: impact of health information technology on quality, efficiency, and costs of medical care. Annals of Internal Medicine, 144(10):742\u2013752, 2006.\n\n[8] Z. Che, S. Purushotham, R. Khemani, and Y. Liu. Distilling knowledge from deep networks with applications to healthcare domain. arXiv preprint arXiv:1512.03542, 2015.\n\n[9] K. Cho, B. Van Merri\u00ebnboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, 2014.\n\n[10] E. Choi, M. T. Bahadori, E. Searles, C. Coffey, and J. Sun. Multi-layer representation learning for medical concepts. In KDD, 2016.\n\n[11] E. Choi, M. T. Bahadori, and J. Sun. Doctor AI: Predicting clinical events via recurrent neural networks. arXiv preprint arXiv:1511.05942, 2015.\n\n[12] J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Bengio. Attention-based models for speech recognition. In NIPS, pages 577\u2013585, 2015.\n\n[13] D. Erhan, Y. Bengio, A. Courville, and P. Vincent. Visualizing higher-layer features of a deep network. University of Montreal, 2009.\n\n[14] C. Esteban, O. Staeck, Y. Yang, and V. Tresp. Predicting clinical events by combining static and dynamic information using recurrent neural networks. arXiv preprint arXiv:1602.02685, 2016.\n\n[15] A. S. Fleisher, B. B. Sowell, C. Taylor, A. C. Gamst, R. C. Petersen, L. J. Thal, and for the Alzheimer\u2019s Disease Cooperative Study. Clinical predictors of progression to Alzheimer disease in amnestic mild cognitive impairment. Neurology, 68(19):1588\u20131595, May 2007.\n\n[16] Agency for Healthcare Research and Quality. Clinical classifications software (CCS) for ICD-9-CM. https://www.hcup-us.ahrq.gov/toolssoftware/ccs/ccs.jsp. Accessed: 2016-04-01.\n\n[17] Agency for Healthcare Research and Quality. Clinical classifications software for services and procedures. https://www.hcup-us.ahrq.gov/toolssoftware/ccs_svcsproc/ccssvcproc.jsp. Accessed: 2016-04-01.\n\n[18] B. Gallego, S. R. Walter, R. O. Day, A. G. Dunn, V. Sivaraman, N. Shah, C. A. Longhurst, and E. Coiera. Bringing cohort studies to the bedside: framework for a \u2018green button\u2019 to support clinical decision-making. Journal of Comparative Effectiveness Research, pages 1\u20137, May 2015.\n\n[19] J. Ghosh and V. Karamcheti. Sequence learning with recurrent networks: analysis of internal representations. In Aerospace Sensing, pages 449\u2013460. International Society for Optics and Photonics, 1992.\n\n[20] C. L. Goldzweig, A. Towfigh, M. Maglione, and P. G. Shekelle. Costs and benefits of health information technology: new trends from the literature. Health Affairs, 28(2):w282\u2013w293, 2009.\n\n[21] K. Gregor, I. Danihelka, A. Graves, and D. Wierstra. DRAW: A recurrent neural network for image generation. arXiv preprint arXiv:1502.04623, 2015.\n\n[22] K. M. Hermann, T. Kocisky, E. Grefenstette, L. Espeholt, W. Kay, M. Suleyman, and P. Blunsom. Teaching machines to read and comprehend. 
In NIPS, pages 1684\u20131692, 2015.\n\n[23] S. Hochreiter and J. Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735\u20131780, 1997.\n\n[24] Wolters Kluwer Clinical Drug Information. Medi-Span electronic drug file (MED-File) v2. http://www.wolterskluwercdi.com/drug-data/medi-span-electronic-drug-file/. Accessed: 2016-04-01.\n\n[25] A. K. Jha, C. M. DesRoches, E. G. Campbell, K. Donelan, S. R. Rao, T. G. Ferris, A. Shields, S. Rosenbaum, and D. Blumenthal. Use of electronic health records in U.S. hospitals. N Engl J Med, 2009.\n\n[26] A. Karpathy, J. Johnson, and F.-F. Li. Visualizing and understanding recurrent networks. arXiv preprint arXiv:1506.02078, 2015.\n\n[27] A. N. Kho, M. G. Hayes, L. Rasmussen-Torvik, J. A. Pacheco, W. K. Thompson, L. L. Armstrong, J. C. Denny, P. L. Peissig, A. W. Miller, W.-Q. Wei, S. J. Bielinski, C. G. Chute, C. L. Leibson, G. P. Jarvik, D. R. Crosslin, C. S. Carlson, K. M. Newton, W. A. Wolf, R. L. Chisholm, and W. L. Lowe. Use of diverse electronic medical record systems to identify genetic risk for type 2 diabetes within a genome-wide association study. JAMIA, 19(2):212\u2013218, Apr. 2012.\n\n[28] Q. V. Le. Building high-level features using large scale unsupervised learning. In ICASSP, 2013.\n\n[29] Q. V. Le, N. Jaitly, and G. E. Hinton. A simple way to initialize recurrent networks of rectified linear units. arXiv preprint arXiv:1504.00941, 2015.\n\n[30] Z. C. Lipton, D. C. Kale, C. Elkan, and R. Wetzel. Learning to Diagnose with LSTM Recurrent Neural Networks. In ICLR, 2016.\n\n[31] A. F. Martins and R. F. Astudillo. From softmax to sparsemax: A sparse model of attention and multi-label classification. In ICML, 2016.\n\n[32] V. Mnih, N. Heess, A. Graves, et al. Recurrent models of visual attention. In NIPS, 2014.\n\n[33] A. M. Rush, S. Chopra, and J. Weston. A neural attention model for abstractive sentence summarization. In EMNLP, 2015.\n\n[34] S. Saria, D. Koller, and A. Penn. 
Learning individual and population level traits from clinical temporal data. In NIPS, Predictive Models in Personalized Medicine workshop, 2010.\n\n[35] P. Schulam and S. Saria. A probabilistic graphical model for individualizing prognosis in chronic, complex diseases. In AMIA, volume 2015, page 143, 2015.\n\n[36] J. Sun, F. Wang, J. Hu, and S. Ebadollahi. Supervised patient similarity measure of heterogeneous patient records. ACM SIGKDD Explorations Newsletter, 14(1):16\u201324, 2012.\n\n[37] K. Xu, J. Ba, R. Kiros, A. Courville, R. Salakhutdinov, R. Zemel, and Y. Bengio. Show, attend and tell: Neural image caption generation with visual attention. In ICML, 2015.\n\n[38] M. D. Zeiler. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701, 2012.", "award": [], "sourceid": 1754, "authors": [{"given_name": "Edward", "family_name": "Choi", "institution": "Georgia Institute of Technolog"}, {"given_name": "Mohammad Taha", "family_name": "Bahadori", "institution": "Gatech"}, {"given_name": "Jimeng", "family_name": "Sun", "institution": "Georgia Tech"}, {"given_name": "Joshua", "family_name": "Kulas", "institution": "Georgia Institute of Technology"}, {"given_name": "Andy", "family_name": "Schuetz", "institution": "Sutter Health"}, {"given_name": "Walter", "family_name": "Stewart", "institution": "Sutter Health"}]}