{"title": "Attentive State-Space Modeling of Disease Progression", "book": "Advances in Neural Information Processing Systems", "page_first": 11338, "page_last": 11348, "abstract": "Models of disease progression are instrumental for predicting patient outcomes and understanding disease dynamics. Existing models provide the patient with pragmatic (supervised) predictions of risk, but do not provide the clinician with intelligible (unsupervised) representations of disease pathophysiology. In this paper, we develop the attentive state-space model, a deep probabilistic model that learns accurate and interpretable structured representations for disease trajectories. Unlike Markovian state-space models, in which the dynamics are memoryless, our model uses an attention mechanism to create \"memoryful\" dynamics, whereby attention weights determine the dependence of future disease states on past medical history. To learn the model parameters from medical records, we develop an infer ence algorithm that simultaneously learns a compiled inference network and the model parameters, leveraging the attentive state-space representation to construct a \"Rao-Blackwellized\" variational approximation of the posterior state distribution. Experiments on data from the UK Cystic Fibrosis registry show that our model demonstrates superior predictive accuracy and provides insights into the progression of chronic disease.", "full_text": "Attentive State-Space Modeling of\n\nDisease Progression\n\nAhmed M. Alaa\nECE Department\n\nUCLA\n\nMihaela van der Schaar\n\nUCLA, University of Cambridge, and\n\nahmedmalaa@ucla.edu\n\n{mv472@cam.ac.uk,mihaela@ee.ucla.edu}\n\nAlan Turing Institute\n\nAbstract\n\nModels of disease progression are instrumental for predicting patient outcomes\nand understanding disease dynamics. 
Existing models provide the patient with pragmatic (supervised) predictions of risk, but do not provide the clinician with intelligible (unsupervised) representations of disease pathology. In this paper, we develop the attentive state-space model, a deep probabilistic model that learns accurate and interpretable structured representations for disease trajectories. Unlike Markovian state-space models, in which state dynamics are memoryless, our model uses an attention mechanism to create "memoryful" dynamics, whereby attention weights determine the dependence of future disease states on past medical history. To learn the model parameters from medical records, we develop an inference algorithm that jointly learns a compiled inference network and the model parameters, leveraging the attentive representation to construct a variational approximation of the posterior state distribution. Experiments on data from the UK Cystic Fibrosis registry show that our model demonstrates superior predictive accuracy, in addition to providing insights into disease progression dynamics.

1 Introduction

Chronic diseases — such as cardiovascular disease, cancer and diabetes — progress slowly throughout a patient's lifetime, causing an increasing burden on patients, their carers, and the healthcare delivery system [1]. The advent of modern electronic health records (EHR) provides an opportunity for building models of disease progression that can predict individual-level disease trajectories and distill intelligible and actionable representations of disease dynamics [2]. 
Models that are both highly accurate and capable of extracting knowledge from data are important for informing practice guidelines and identifying patients' needs and interactions with health services [3, 4, 5, 6].
In this paper, we develop a deep probabilistic model of disease progression that capitalizes on both the interpretable structured representations of probabilistic models and the predictive strength of deep learning methods. Our model uses a state-space representation to segment a patient's disease trajectory into "stages" of progression that manifest through clinical observations. But unlike conventional state-space models, which are predominantly Markovian, our model uses recurrent neural networks (RNNs) to capture more complex state dynamics. The proposed model learns hidden disease states from observational data in an unsupervised fashion, and hence it is suitable for EHR data, where a patient's record is seldom annotated with "labels" indicating the true health state [7].
Our model uses an attention mechanism to capture state dynamics, hence we call it an attentive state-space model. The attention mechanism observes the patient's clinical history and maps it to attention weights that determine how much influence previous disease states have on future state transitions. In that sense, the attention weights generated for an individual patient explain the causative and associative relationships between the hidden disease states and the past clinical events for that patient. Because

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

(a) RNN (b) HMM (c) Attentive state space

Figure 1: Sequential data models. (a) Graphical model for an RNN; ◆ denotes a deterministic representation. (b) Graphical model for an HMM; ○ denotes probabilistic states. (c) Graphical depiction of an attentive state-space model. 
With a slight abuse of graphical model notation, the thickness of the arrows reflects the attention weights.

the attention mechanism can be made arbitrarily complex, our model can capture complex dynamics while maintaining its structural interpretability. We implement this dynamic attention mechanism via a sequence-to-sequence RNN architecture [8].
Because our model is non-Markovian, inference of posterior disease states is intractable and cannot be conducted using standard forward-backward routines (e.g., [9, 10, 11]). To address this issue, we devise a structured inference network trained to predict posterior state distributions by mimicking the attentive structure of our model. The inference network shares attention weights with the generative model, and uses those weights to create the summary statistics needed for posterior state inference. We jointly train the inference and model networks using stochastic gradient descent.
To demonstrate the practical significance of the attentive state-space model, we use it to model cystic fibrosis progression trajectories using data from the UK Cystic Fibrosis registry. Our experiments show that attentive state-space models can extract clinically meaningful representations of disease progression while maintaining superior predictive accuracy for future outcomes.
Related work. Various predictive models based on RNNs have recently been developed for healthcare settings — e.g., "Doctor AI" [12], "L2D" [13], and "Disease-Atlas" [14]. Unfortunately, RNNs are of a "black-box" nature, since their hidden states do not correspond to clinically meaningful variables (Figure 1a). 
Thus, none of the aforementioned methods provides an intelligible model of disease progression; they are limited to predicting a target outcome.
There have been various attempts to create interpretable RNN-based predictive models using attention. The models in [15, 16, 17] use a reverse-time attention mechanism to learn visit-level attention weights that explain the predictions of an RNN. The main difference between the way attention is used in our model and in models like "RETAIN" [15] is that our model applies attention to the latent state space, whereas RETAIN applies attention to the observable sample space. Hence, attention yields different types of explanations in the two models. In our model, attention interprets the hidden disease dynamics, and hence provides an explanation for the mechanisms underlying disease progression. By contrast, RETAIN uses attention to measure feature importance, so it can only explain discriminative predictions, not the underlying generative disease dynamics.
Almost all existing models of disease progression are based on variants of the HMM [18, 9, 19]. Disease dynamics in such models are easily interpretable, as they can be fully summarized by a single matrix of probabilities that describes the transition rates among disease states. Markovian dynamics also simplify inference, because the model likelihood factorizes in a way that makes efficient forward and backward message passing possible. However, memoryless Markov models assume that a patient's current state d-separates her future trajectory from her clinical history (Figure 1b). This renders HMM-based models incapable of properly explaining the heterogeneity in patients' progression trajectories, which often results from their varying clinical histories or the chronologies (timing and order) of their clinical events [5]. 
This limitation is crucial in complex chronic diseases that are accompanied by multiple comorbidities. Our model addresses this limitation by creating memoryful state transitions that depend on the patient's entire clinical history (Figure 1c).
Most existing works on deep probabilistic models have focused on developing structured inference algorithms for deep Markov models and their variants [20, 10, 21, 22, 23]. All such models use neural networks to model the transition and emission distributions, but are limited to Markovian dynamics. Other works develop stochastic versions of RNNs for the sake of generative modeling; examples include variational RNNs [24], SRNN [25], and STORN [26]. These models augment an RNN with stochastic layers in order to enrich its output distribution. However, the transition and emission distributions in such models cannot be decoupled, and hence their latent states do not map to a clinically meaningful identification of disease states. To the best of our knowledge, ours is the first deep probabilistic model that provides clinically meaningful latent representations, with non-Markovian state dynamics that can be made arbitrarily complex while remaining interpretable.

2 Attentive State-Space Models

2.1 Structure of the EHR Data

A patient's EHR record, denoted as x⃗_T, is a collection of sequential follow-up data gathered during repeated hospital visits. We represent a given patient's record as

x⃗_T = {x_t}_{t=1}^{T},    (1)

where x_t is the follow-up data collected during the t-th hospital visit, and T is the total number of visits. The follow-up data x_t ∈ X is a multi-dimensional vector that comprises information on biomarkers and clinical events, such as treatments and ICD-10 diagnosis codes [2].

2.2 Attentive State-Space Representation

We model the progression trajectory of the target disease via a state-space representation. That is, at each time step t, the patient's health is characterized by a state z_t ∈ Z which manifests through the follow-up data x_t. The state space is the (discrete) set of all possible stages of disease progression, Z = {1, . . ., K}. In general, progression stages correspond to distinct disease phenotypes. For instance, chronic kidney disease progresses through 5 stages (Stage I to Stage V), each of which corresponds to a different level of renal dysfunction [27]. We assume that {z_t}_t is hidden, i.e., the true health state of a patient is not observed, and should be learned in an unsupervised fashion. We model the joint distribution of states and observations via the following factorization:

p_θ(x⃗_T, z⃗_T) = ∏_{t=1}^{T} p_θ(x_t | z_t) · p_θ(z_t | x⃗_{t−1}, z⃗_{t−1}),    (2)

where the first factor is the emission distribution, the second factor is the transition distribution, z⃗_t = {z_1, . . ., z_t}, 1 ≤ t ≤ T, and θ is the set of parameters of our model.
Attentive state transitions. What makes the model in (2) differ from standard state-space models? The main difference is that the transition probability in (2) assumes that the patient's health state at time t depends on their entire history (x⃗_{t−1}, z⃗_{t−1}). This is a major departure from the standard Markovian assumption, which posits that p_θ(z_t | x⃗_{t−1}, z⃗_{t−1}) = p_θ(z_t | z_{t−1}), i.e., that future states depend only on the current state. 
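To make the factorization in (2) concrete, the log-probability of a trajectory is a sum of per-step emission terms and history-dependent transition terms. A minimal sketch, where `log_emission` and `log_transition` are hypothetical stand-ins for the learned factors p_θ(x_t | z_t) and p_θ(z_t | x⃗_{t−1}, z⃗_{t−1}):

```python
def log_joint(x, z, log_emission, log_transition):
    """Score a trajectory under the factorization in (2):
    p(x_{1:T}, z_{1:T}) = prod_t p(x_t | z_t) * p(z_t | x_{1:t-1}, z_{1:t-1})."""
    total = 0.0
    for t in range(len(x)):
        total += log_emission(x[t], z[t])            # emission: current state only
        total += log_transition(z[t], x[:t], z[:t])  # transition: conditions on full history
    return total
```

The non-Markovian character of the model is visible in the signature: the transition term receives the entire past (x⃗_{t−1}, z⃗_{t−1}), not just z_{t−1}.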
Most existing disease progression models are Markovian (e.g., [9, 18]). To capture non-Markovian dynamics, we model the state transition distribution as follows:

p_θ(z_t | x⃗_{t−1}, z⃗_{t−1}) = p_θ(z_t | α⃗_t, z⃗_{t−1}),    (3)

where α⃗_t = {α^t_1, . . ., α^t_{t−1}}, α^t_i ∈ [0, 1], ∀i, Σ_i α^t_i = 1, is a set of attention weights that act as sufficient statistics of future states. The attention weights admit a simple interpretation: they determine the influences of past state realizations on future state transitions via the linear dynamic

p_θ(z_t | α⃗_t, z⃗_{t−1}) = Σ_{t′=1}^{t−1} α^t_{t′} P(z_{t′}, z_t), ∀t ≥ 1,    (4)

where P is a baseline state transition matrix, i.e., P = [p_ij], p_ij ∈ [0, 1], Σ_j p_ij = 1, and the initial state distribution is π = [p_1, . . ., p_K]. The attention weights α⃗_t assigned to all previous state realizations at time t are generated using the patient's context x⃗_t via an attention mechanism A as follows:

α⃗_t = A_t(x⃗_t),    (5)

where A is a deterministic algorithm that generates a sequence of functions {A_t}_t, A_t : X^t → [0, 1]^t. We specify our choice of the attention mechanism in Section 2.3.
Emission distribution. The follow-up data x_t = (x^c_t, x^b_t) comprises both a continuous component x^c_t (e.g., biomarkers and test results) and a binary component x^b_t (e.g., clinical events and ICD-10 codes). 
To capture both components, we model the emission distribution in (2) through the factors p_θ(x_t | z_t) = p_θ(x^b_t | x^c_t, z_t) · p_θ(x^c_t | z_t), where

p_θ(x^c_t | z_t) = N(μ_{z_t}, Σ_{z_t}),  p_θ(x^b_t | x^c_t, z_t) = Bernoulli(Logistic(x^c_t, Λ_{z_t})).    (6)

The model in (6) specifies state-specific distributions for the binary (Bernoulli) and continuous (Gaussian) variables, with state-specific emission parameters (μ_{z_t}, Σ_{z_t}, Λ_{z_t}). Thus, an attentive state-space model can be completely specified through the parameter set θ = (π, P, A, μ, Σ, Λ).
Generality of the attentive representation. For particular choices of the attention mechanism in (5), our model reduces to various classical models of sequential data, as shown in Table 1. The generality of the attentive state representation is a powerful feature because it implies that, by learning the attention functions {A_t}_t, we are effectively testing the structural assumptions of various commonly-used time series models in a data-driven fashion.

Model                    Attention mechanism
HMM [9]                  α^t_{t−1} = 1; α^t_j = 0, j ≤ t − 2.
Order-m HMM [28]         α^t_j ∝ 1{t − m ≤ j ≤ t − 1}.
Variable-order HMM [29]  α^t_j ∈ {0, n̄^{−1}}, n̄ = Σ_i 1{α^t_i ≥ γ}.

Table 1: Representation of familiar sequential data models in terms of attention weights.

2.3 Sequence-to-sequence Attention Mechanism

To complete the specification of our model, we now specify the attention mechanism A in (5). Recall that A is a sequence of deterministic functions that map a patient's context x⃗_t to a set of attention weights α⃗_t at each time step. 
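To make the transition model in (4) and the HMM row of Table 1 concrete, here is a small sketch (assuming numpy; `P` and `alpha` play the roles of the baseline matrix and the weights α⃗_t of Section 2.2). The next-state distribution is a convex, attention-weighted mixture of rows of P, and putting all attention on the most recent state recovers a first-order Markov transition:

```python
import numpy as np

def attentive_transition(alpha, P, z_hist):
    """p(z_t = k | alpha, z_hist) = sum_{t'} alpha[t'] * P[z_hist[t'], k]:
    an attention-weighted mixture of rows of the baseline transition matrix P."""
    probs = np.zeros(P.shape[0])
    for a, z_prev in zip(alpha, z_hist):
        probs += a * P[z_prev]
    return probs

# HMM special case (Table 1): alpha = (0, ..., 0, 1) attends only to z_{t-1},
# so the result is exactly the Markov transition row P[z_hist[-1]].
```

Because each row of P sums to one and the attention weights sum to one, the mixture is automatically a valid probability vector.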
Since our model must output an entire sequence of attention weights at every time step, we implement A via a sequence-to-sequence (Seq2Seq) model [8].
Our Seq2Seq model uses an LSTM encoder-decoder architecture, as shown in Figure 2. For each time step t, the patient context x⃗_t is fed to the LSTM encoder, and the final state of the encoder, h_t, is viewed as a fixed-size representation of the patient's context; it is passed together with the last output O to the decoder side.

Figure 2: Seq2Seq architecture for the attention mechanism A.

In the decoding phase, the last state of the encoding LSTM is used as the initial state of the decoding LSTM, and O is used as its first input. Then, the decoding LSTM iteratively uses its output at one time step as its input for the next step. After t − 1 decoding iterations, we collect the t − 1 (normalized) attention weights via a Softmax output layer.
The main difference between our architecture and other Seq2Seq models — often used in language translation tasks [30, 8] — is that in our case, we learn an entire sequence of attention weights for each of the T data vectors in x⃗_T. We achieve this by running t − 1 decoding iterations to collect t − 1 outputs for every single encoding step. Moreover, in our setup the attention sequence is itself the target sequence being learned. This should not be confused with other Seq2Seq schemes with attention, where attention is used as an intermediate representation within a decoding procedure [31].

2.4 Why Attentive State-Space Modeling?

Most existing models of disease progression are based on hidden Markov models [19, 9, 18, 32]. However, the Markovian dynamic is oversimplified: in reality, a patient's transition to a given state depends not only on her current stage, but also on her individual history of past clinical events [1]. 
In this sense, a Markov model is of a "one-size-fits-all" nature — under a Markov model, all patients at the same stage of progression would have the same expected future trajectory, irrespective of their potentially different individual clinical histories. Because Markov models explain away individual-level variations in progression trajectories, their interpretable nature should be thought of as a bug and not a feature, i.e., a Markov model is easily interpretable only because it does not explain much; it only encodes our own prior assumptions about disease dynamics.
Attentive state-space models overcome the shortcomings of Markov models by using attention weights to create a non-stationary, variable-order generalization of Markovian transitions, whereby the dynamics of each patient change over time based on her individual clinical context. An attentive state-space model can learn state dynamics that are as complex as those of an RNN, but through the factorization in (2), it ensures that its hidden states correspond to clinically meaningful disease states.

3 Attentive Variational Inference

Learning the model parameters θ and inferring a patient's health state in real time require computing the posterior p_θ(z⃗_t | x⃗_t). However, the non-Markovian nature of our model renders posterior computation intractable. 
In this Section, we develop a variational learning algorithm that jointly learns the model parameters θ and a structured inference network that approximates the posterior p_θ(z⃗_t | x⃗_t). We show that the attentive representation proposed in Section 2 is useful not only for improving predictions and extracting clinical knowledge, but also for improving structured inference.

3.1 Variational Lower Bound

In variational learning, we maximize an evidence lower bound (ELBO) on the data likelihood, i.e.,

log p_θ(x⃗_T) ≥ E_{q_φ}[ log p_θ(x⃗_T, z⃗_T) − log q_φ(z⃗_T | x⃗_T) ],

where q_φ(z⃗_T | x⃗_T) is a variational distribution that approximates the posterior p_θ(z⃗_T | x⃗_T). We model the variational distribution q_φ(z⃗_T | x⃗_T) using an inference network that is trained jointly with the model through the following optimization problem [33, 34]:

(θ*, φ*) = arg max_{θ, φ} E_{q_φ}[ log p_θ(x⃗_T, z⃗_T) − log q_φ(z⃗_T | x⃗_T) ].    (7)

By estimating θ and φ from the EHR data, we recover the generative model p_θ(x⃗_T, z⃗_T), through which we can extract clinical knowledge, and the inference network q_φ(z⃗_T | x⃗_T), which we can use to infer the health trajectory of the patient at hand.

3.2 Attentive Inference Network

We construct the inference network q_φ(z⃗_T | x⃗_T) so that it mimics the structure of the true posterior [20]. 
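Given any candidate q_φ, the expectation in (7) can be estimated by plain Monte Carlo over trajectories sampled from the inference network. A minimal sketch, where `sample_z`, `log_joint`, and `log_q` are hypothetical stand-ins for the inference-network sampler and the two log-densities:

```python
def elbo_estimate(x, sample_z, log_joint, log_q, n_samples=100):
    """Monte Carlo estimate of the ELBO in (7):
    E_q[ log p(x, z) - log q(z | x) ], averaged over z ~ q(. | x)."""
    total = 0.0
    for _ in range(n_samples):
        z = sample_z(x)                          # sample a state trajectory
        total += log_joint(x, z) - log_q(z, x)   # one-sample ELBO term
    return total / n_samples
```

This is the quantity that the learning algorithm of Section 3.3 differentiates with respect to θ and φ.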
Recall that the posterior factorizes as follows:

p_θ(z⃗_T | x⃗_T) = p_θ(z_1 | x⃗_T) ∏_{t=2}^{T} p_θ(z_t | α⃗_t, z⃗_{t−1}, x⃗_{t:T}).

Consequently, we impose a similar factorization on the inference network, i.e.,

q_φ(z⃗_T | x⃗_T) = q_φ(z_1 | x⃗_T) ∏_{t=2}^{T} q_φ(z_t | α⃗_t, z⃗_{t−1}, x⃗_{t:T}).    (8)

To capture the factorization in (8), we use the architecture in Figure 3 to construct an inference network that mimics the attentive structure of the generative model. In this architecture, a "combiner function" C(.) is fed with all the sufficient statistics of a state z_t, and outputs its posterior distribution. The combiner uses the attention weights created by A to condense the summary statistics of z_t.
As dictated by (8), the parent nodes of z_t are the attention weights α⃗_t, the previous states z⃗_{t−1}, and the future observations x⃗_{t:T}. The inference network encodes these sufficient statistics as follows. The attention weights α⃗_t are shared with the attention network in Figure 2. The future observations x⃗_{t:T} are summarized at time t via a backward LSTM that reads x⃗_T in reversed order, as shown in Figure 3. Finally, the previous states z⃗_{t−1} are sampled from the combiner functions at previous time steps, as described below.

Posterior sampling. In order to sample posterior state trajectories from the inference network, we iterate over the combiner function C(.) for t ∈ {1, . . ., T} as follows:

p̃_t = C(h^t_q, α⃗_t, (z̃_1, . . ., z̃_{t−1}) | π, P),
z̃_t ∼ Multinomial(p̃_t),    (9)

where p̃_t = (p̃^t_1, . . ., p̃^t_K), Σ_k p̃^t_k = 1, is the posterior state distribution estimated by the inference network at time t, and h^t_q is the t-th state of the backward LSTM in Figure 3, which summarizes the information in x⃗_{t:T}.

Figure 3: Attentive inference network.

As we can see in (9), at each time step t, the combiner function takes as input all the previous states (z̃_1, . . ., z̃_{t−1}) sampled by earlier executions of the combiner function. The dashed blue lines in Figure 3 depict the passage of older state samples to later executions of the combiner function.
The combiner function estimates the posterior p̃_t by emulating the state transition model in (4), i.e.,

p̃^t_{k,forward} = Σ_{t′=1}^{t−1} α^t_{t′} P(z̃_{t′}, k), k ∈ {1, . . ., K},
h̃^t_q = [ h^t_q, p̃^t_{1,forward}, . . ., p̃^t_{K,forward} ],
p̃_t = Softmax(W_q^⊤ h̃^t_q + b_q).    (10)

As shown in (10), the combiner emulates the generative model to compute an estimate of the "filtering" distribution, p̃^t_{k,forward} ≈ p_θ(z_t = k | x⃗_t); i.e., it attends to previously sampled states with proportions determined by the attention weights. Then, to augment information from the future observations x⃗_{t:T}, it concatenates the filtering distribution with the backward LSTM state and estimates the posterior through a Softmax output layer.

3.3 Learning with Stochastic Gradient Descent

In order to simultaneously learn the parameters of the generative model and the inference network, we use stochastic gradient descent to solve (7) as follows:

1. Sample (z̃^(i)_1, . . ., z̃^(i)_T) ∼ q_φ(z⃗_T | x⃗_T), i = 1, . . ., N.
2. Estimate the ELBO L̂ = (1/N) Σ_i ℓ_{θ,φ}(x⃗_T, z̃^(i)_1, . . ., z̃^(i)_T).
3. 
Estimate the gradients ∇_θ L̂ and ∇_φ L̂.
4. Update φ and θ.

In Step 2, the term ℓ_{θ,φ}(.) denotes the objective function in (7). We estimate the gradients in Step 3 via stochastic backpropagation [35]. In Step 4, we use ADAM [36] to update the parameters of the attention mechanism (Figure 2) and the inference network (Figure 3). The emission parameters are updated straightforwardly by their maximum likelihood estimates.
Rao-Blackwellization via attention. As we have seen, our attentive inference network architecture enables sharing parameters between the generative model and the inference model, which accelerates learning. Another key advantage of the attentive structure is that it acts as a Rao-Blackwellization of the conventional structured inference network, which conditions on all observations (i.e., q_φ(z_t | x⃗_T)) [20, 11, 21]. Because the attention weights (together with z⃗_{t−1} and x⃗_{t:T}) act as sufficient statistics for state transitions, our inference network guides the posterior to focus only on the pieces of information that matter. Rao-Blackwellization helps reduce the variance of the gradient estimates (Step 3 in the learning algorithm above), and hence accelerates learning [37].

4 Experiments

In this Section, we use our attentive state-space framework to model cystic fibrosis (CF) progression trajectories. CF is a life-shortening disease that causes lung dysfunction, and is the most common genetic disease in Caucasian populations [38]. Experimental details are listed hereunder.
Implementation. We implemented our model using TensorFlow1. 
The LSTM cells in both the attention network (Figure 2) and the inference network (Figure 3) had 2 hidden layers of size 100. The model and inference networks were trained using ADAM with a learning rate of 5 × 10^−4 and a mini-batch size of 100. The same hyperparameter setting was used for all baseline models involving RNNs. All prediction results reported in this Section were obtained via 5-fold cross-validation.
Data description. We used data from a cohort of patients enrolled in the UK CF registry, a database held by the UK CF Trust2. The dataset records annual follow-ups for 10,263 patients over the period from 2008 to 2015, with a total of 60,218 hospital visits. Each patient is associated with 90 variables, including information on 36 possible treatments, diagnoses for 31 possible comorbidities and 16 possible infections, in addition to biomarkers and demographic information. The FEV1 biomarker (a measure of lung function) is the main measure of illness severity in CF patients [39].
Training. In Figure 4, we show the model's log-likelihood (LL) versus the number of training epochs. As we can see, the more training iterations we apply, the better the model likelihood gets: it jumped from −4 × 10^6 in the initial iterations to −8 × 10^5 after training was completed. The best value of the log-likelihood is 0, which is achieved when the inference network q_φ(z_t | x⃗_T) coincides with the true model p_θ(z_t | x⃗_T) and the observed data likelihood given the model is 1. Attentive inference is accurate because it utilizes the minimally sufficient set of past information, which reduces the variance in gradient estimates (Section 3.3).
Use cases. 
We assess our model with respect to the two use cases it was designed for: (1) extracting clinical knowledge on disease progression mechanisms from the data, and (2) predicting a patient's health trajectory over time. We assess each use case separately in Sections 4.1 and 4.2.

Figure 4: LL vs. training epochs.

4.1 Understanding CF Progression Mechanisms

Population-level phenotyping. Our model learned a representation of K = 3 CF progression stages (Stages 1, 2 and 3) in an unsupervised fashion, i.e., each stage is a realization of the hidden state z_t. As we show in what follows, each learned progression stage corresponded to a clinically distinguishable phenotype of disease activity. The learned baseline transition matrix was

P = [ 0.85 0.10 0.05
      0.13 0.72 0.15
      0.24 0.10 0.66 ].

Figure 5: Distribution of observations in each progression stage.

The FEV1 biomarker is currently used by clinicians as a proximal measure of a patient's health in order to guide clinical and therapeutic decisions [40]. To check that the learned progression stages correspond to different levels of disease severity, we plot the estimated mean of the emission distribution for the FEV1 biomarker in Stages 1, 2 and 3 in Figure 5 (left). As we can see, the mean values of the FEV1 biomarker in the three stages were 79%, 68% and 53%, respectively. These values match the cutoff values on FEV1 used in current guidelines for referring critically-ill patients for a lung transplant [40]. 
Thus, the learned progression stages can be translated into actionable information for clinical decision-making.

1 The code is provided at https://bitbucket.org/mvdschaar/mlforhealthlabpub.
2 https://www.cysticfibrosis.org.uk/the-work-we-do/uk-cf-registry/

Model         Diabetes       ABPA           Depression     Pancreatitis   P. Aeruginosa
Attentive SS  0.709 ± 0.02   0.787 ± 0.01   0.751 ± 0.03   0.696 ± 0.04   0.680 ± 0.01
HMM           0.625 ± 0.02   0.686 ± 0.03   0.667 ± 0.08   0.625 ± 0.04   0.610 ± 0.02
RNN           0.634 ± 0.03   0.727 ± 0.10   0.575 ± 0.01   0.590 ± 0.06   0.654 ± 0.01
LSTM          0.675 ± 0.03   0.740 ± 0.07   0.609 ± 0.12   0.578 ± 0.05   0.671 ± 0.01
RETAIN        0.610 ± 0.06   0.718 ± 0.05   0.580 ± 0.09   0.600 ± 0.08   0.676 ± 0.02

Table 2: Performance of the different competing models for the 5 prognostic tasks under consideration (all values are AUC-ROC).

The progression stages learned by our model represented clinically distinguishable phenotypes with respect to multiple clinical variables. To illustrate these phenotypes, in Figure 5 (right) we plot the risks of various comorbidities (diabetes, asthma, ABPA, hypertension and depression) for patients in the 3 CF progression stages learned by the model. (Those risks were obtained directly from the learned emission distribution corresponding to the binary component x^b_t of the clinical observation x_t.) As we can see, the incidences of those comorbidities and infections increase significantly in the more severe progression Stages 2 and 3 as compared to Stage 1.

Individualized contextual diagnosis. 
Population-level modeling of disease stages can already be obtained with simple HMM models, but our model captures more complex dynamics that are specific to individuals, and can be made non-Markovian and non-stationary depending on the patient's context. To demonstrate the complex and non-stationary nature of the learned state dynamics, we plot the average attention weights assigned to the patients' previous state realizations at every "chronological" time step of a patient trajectory. The average attention weights per time step are plotted in Figure 6.

Figure 6: Average attention weights over time.

As we can see, a patient's state trajectory behaves in a quasi-Markovian fashion (only the current state takes all the weight) only at its edges. That is, at the first time step and the last time step, the only thing that matters for prediction is the patient's current state. This is because at the first time step the patient has no history, whereas at the final step the patient is already in the most severe state, and hence her current health deterioration overrides all past clinical events. Memory becomes important only in the intermediate stages; this is because patients in Stages 2 and 3 are more likely to have been diagnosed with more comorbidities in the past.

4.2 Predicting Prognosis

As we have seen in Section 4.1, our model is capable of extracting clinical intelligence from data, but does this compromise its predictive ability? To test the predictive ability of attentive state-space models, we sequentially predict the 1-year risk of 4 comorbidities (ABPA, diabetes, depression and pancreatitis) and 1 lung infection (Pseudomonas aeruginosa) that are common in the CF population. We use the area under the receiver operating characteristic curve (AUC-ROC) for performance evaluation. We report average AUC-ROC with 95% confidence intervals.
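The evaluation metric can be reproduced in a few lines. A minimal sketch of AUC-ROC with a percentile-bootstrap 95% confidence interval; the synthetic labels and scores below are purely illustrative stand-ins (the registry data and the models' actual risk scores are not reproduced here).

```python
import numpy as np

def auc_roc(y_true, y_score):
    """AUC-ROC via the Mann-Whitney statistic: the probability that a
    randomly chosen positive is scored above a randomly chosen negative."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    wins = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return wins + 0.5 * ties

def auc_ci(y_true, y_score, n_boot=1000, alpha=0.05, seed=0):
    """Percentile-bootstrap confidence interval for the AUC-ROC."""
    rng = np.random.default_rng(seed)
    stats, n = [], len(y_true)
    while len(stats) < n_boot:
        idx = rng.integers(0, n, size=n)
        if 0 < y_true[idx].sum() < n:  # resample must contain both classes
            stats.append(auc_roc(y_true[idx], y_score[idx]))
    return np.quantile(stats, [alpha / 2, 1 - alpha / 2])

# Illustrative 1-year comorbidity labels and model risk scores.
rng = np.random.default_rng(1)
y = rng.integers(0, 2, size=200)
s = np.clip(0.5 * y + rng.normal(0.3, 0.25, size=200), 0.0, 1.0)

auc = auc_ci_lo_hi = auc_roc(y, s)
lo, hi = auc_ci(y, s)
print(f"AUC-ROC = {auc:.3f}, 95% CI = [{lo:.3f}, {hi:.3f}]")
```

Bootstrap percentiles are one common way to attach a 95% interval to AUC-ROC; the paper does not specify its interval-estimation procedure, so this is a plausible reconstruction rather than the authors' exact protocol.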
We compare our model with 4 baselines: a vanilla RNN and an LSTM trained for sequence prediction, RETAIN, a state-of-the-art predictive model for healthcare data [17, 15], and an HMM.

As we can see in Table 2, our model did not incur any performance loss when compared to models trained and tailored for the given prediction task (RNN, LSTM and RETAIN), and was in fact more accurate on all 5 tasks. The predictive power of attentive state-space models comes from the use of LSTM networks to model state dynamics in a low-dimensional space that summarizes the 90 variables associated with each patient. While HMMs can also learn interpretable representations of disease progression, they displayed modest predictive performance because of their oversimplified Markovian dynamics. Because attentive state-space models combine the interpretational benefits of probabilistic models with the predictive strength of deep learning, we envision them being used for large-scale disease phenotyping and clinical decision-making.

Acknowledgments

This work was supported by the National Science Foundation (NSF grants 1462245 and 1533983), and the US Office of Naval Research (ONR). The data for our experiments was provided by the UK Cystic Fibrosis Trust. We thank Dr. Janet Allen (Director of Strategic Innovation, UK Cystic Fibrosis Trust) for the vision and encouragement. We thank Rebecca Cosgriff and Elaine Gunn for the help with data access, extraction and analysis.

References

[1] Mary Ann Sevick, Jeanette M Trauth, Bruce S Ling, Roger T Anderson, Gretchen A Piatt, Amy M Kilbourne, and Robert M Goodman. Patients with complex chronic diseases: perspectives on supporting self-management.
Journal of General Internal Medicine, 22(3):438–444, 2007.

[2] David Blumenthal and Marilyn Tavenner. The "meaningful use" regulation for electronic health records. New England Journal of Medicine, 363(6):501–504, 2010.

[3] Eric J Topol. High-performance medicine: the convergence of human and artificial intelligence. Nature Medicine, 25(1):44, 2019.

[4] DW Coyne. Management of chronic kidney disease comorbidities. CKD Medscape CME Expert Column Series, (3), 2011.

[5] Jose M Valderas, Barbara Starfield, Bonnie Sibbald, Chris Salisbury, and Martin Roland. Defining comorbidity: implications for understanding health and health services. The Annals of Family Medicine, 7(4):357–363, 2009.

[6] Ioana Bica, Ahmed M Alaa, James Jordon, and Mihaela van der Schaar. Estimating counterfactual treatment outcomes over time through adversarially balanced representations. International Conference on Learning Representations, 2020.

[7] Ahmed M Alaa and Mihaela van der Schaar. A hidden absorbing semi-Markov model for informatively censored temporal data: Learning and inference. Journal of Machine Learning Research, 2018.

[8] Ilya Sutskever, Oriol Vinyals, and Quoc V Le. Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems, pages 3104–3112, 2014.

[9] Yu-Ying Liu, Shuang Li, Fuxin Li, Le Song, and James M Rehg. Efficient learning of continuous-time hidden Markov models for disease progression. In Advances in Neural Information Processing Systems, pages 3600–3608, 2015.

[10] Hanjun Dai, Bo Dai, Yan-Ming Zhang, Shuang Li, and Le Song. Recurrent hidden semi-Markov model. International Conference on Learning Representations, 2016.

[11] Xun Zheng, Manzil Zaheer, Amr Ahmed, Yuan Wang, Eric P Xing, and Alexander J Smola. State space LSTM models with particle MCMC inference.
arXiv preprint arXiv:1711.11179, 2017.

[12] Edward Choi, Mohammad Taha Bahadori, Andy Schuetz, Walter F Stewart, and Jimeng Sun. Doctor AI: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference, pages 301–318, 2016.

[13] Zachary C Lipton, David C Kale, Charles Elkan, and Randall Wetzel. Learning to diagnose with LSTM recurrent neural networks. International Conference on Learning Representations, 2016.

[14] Bryan Lim and Mihaela van der Schaar. Disease-Atlas: Navigating disease trajectories with deep learning. Machine Learning for Healthcare Conference (MLHC), 2018.

[15] Edward Choi, Mohammad Taha Bahadori, Jimeng Sun, Joshua Kulas, Andy Schuetz, and Walter Stewart. RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism. In Advances in Neural Information Processing Systems, pages 3504–3512, 2016.

[16] Fenglong Ma, Radha Chitta, Jing Zhou, Quanzeng You, Tong Sun, and Jing Gao. Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 1903–1911. ACM, 2017.

[17] Bum Chul Kwon, Min-Je Choi, Joanne Taery Kim, Edward Choi, Young Bin Kim, Soonwook Kwon, Jimeng Sun, and Jaegul Choo. RetainVis: Visual analytics with interpretable and interactive recurrent neural networks on electronic medical records. IEEE Transactions on Visualization and Computer Graphics, 25(1):299–309, 2019.

[18] Xiang Wang, David Sontag, and Fei Wang. Unsupervised learning of disease progression models. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 85–94. ACM, 2014.

[19] Ahmed M Alaa, Scott Hu, and Mihaela van der Schaar. Learning from clinical judgments: Semi-Markov-modulated marked Hawkes processes for risk prognosis.
International Conference on Machine Learning, 2017.

[20] Rahul G Krishnan, Uri Shalit, and David Sontag. Structured inference networks for nonlinear state space models. In AAAI, pages 2101–2109, 2017.

[21] Maximilian Karl, Maximilian Soelch, Justin Bayer, and Patrick van der Smagt. Deep variational Bayes filters: Unsupervised learning of state space models from raw data. arXiv preprint arXiv:1605.06432, 2016.

[22] Matthew Johnson, David K Duvenaud, Alex Wiltschko, Ryan P Adams, and Sandeep R Datta. Composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Information Processing Systems, pages 2946–2954, 2016.

[23] Syama Sundar Rangapuram, Matthias W Seeger, Jan Gasthaus, Lorenzo Stella, Yuyang Wang, and Tim Januschowski. Deep state space models for time series forecasting. In Advances in Neural Information Processing Systems, pages 7796–7805, 2018.

[24] Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. In Advances in Neural Information Processing Systems, pages 2980–2988, 2015.

[25] Marco Fraccaro, Søren Kaae Sønderby, Ulrich Paquet, and Ole Winther. Sequential neural models with stochastic layers. In Advances in Neural Information Processing Systems, pages 2199–2207, 2016.

[26] Justin Bayer and Christian Osendorfer. Learning stochastic recurrent networks. arXiv preprint arXiv:1411.7610, 2014.

[27] Allison A Eddy and Eric G Neilson. Chronic kidney disease progression. Journal of the American Society of Nephrology, 17(11):2964–2966, 2006.

[28] Frans MJ Willems, Yuri M Shtarkov, and Tjalling J Tjalkens. The context-tree weighting method: basic properties. IEEE Transactions on Information Theory, 41(3):653–664, 1995.

[29] Ron Begleiter, Ran El-Yaniv, and Golan Yona.
On prediction using variable order Markov models. Journal of Artificial Intelligence Research, 22:385–421, 2004.

[30] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. International Conference on Learning Representations, 2015.

[31] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000–6010, 2017.

[32] Peter J Green and Sylvia Richardson. Hidden Markov models and disease mapping. Journal of the American Statistical Association, 97(460):1055–1070, 2002.

[33] Andriy Mnih and Karol Gregor. Neural variational inference and learning in belief networks. arXiv preprint arXiv:1402.0030, 2014.

[34] Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. International Conference on Learning Representations, 2014.

[35] John Schulman, Nicolas Heess, Theophane Weber, and Pieter Abbeel. Gradient estimation using stochastic computation graphs. In Advances in Neural Information Processing Systems, pages 3528–3536, 2015.

[36] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

[37] Rajesh Ranganath, Sean Gerrish, and David Blei. Black box variational inference. In Artificial Intelligence and Statistics, pages 814–822, 2014.

[38] Rhonda D Szczesniak, Dan Li, Weiji Su, Cole Brokamp, John Pestian, Michael Seid, and John P Clancy. Phenotypes of rapid cystic fibrosis lung disease progression during adolescence and young adulthood. American Journal of Respiratory and Critical Care Medicine, 196(4):471–478, 2017.

[39] Don B Sanders, Lucas R Hoffman, Julia Emerson, Ronald L Gibson, Margaret Rosenfeld, Gregory J Redding, and Christopher H Goss.
Return of FEV1 after pulmonary exacerbation in children with cystic fibrosis. Pediatric Pulmonology, 45(2):127–134, 2010.

[40] Andrew T Braun and Christian A Merlo. Cystic fibrosis lung transplantation. Current Opinion in Pulmonary Medicine, 17(6):467–472, 2011.