van der Schaar Lab

Survival analysis, competing risks, and comorbidities


Machine learning is capable of enabling truly personalized healthcare; this is what our lab calls “bespoke medicine.”

More information on bespoke medicine can be found here.

Survival analysis (the analysis of data where the target variable is the time until a certain event occurs) is one of several key prognostic areas that are essential to making bespoke medicine a reality: effective survival analysis models can enable clinicians to generate bespoke survival predictions for individual patients. This, in turn, can inform the development of bespoke treatment plans for each patient at any point in time.

To achieve this, however, survival analysis models need to be accurate, trustworthy, and interpretable. They must also take into account competing risks, comorbidities, and multiple events, and be capable of operating in static and dynamic settings.

Creating such a model is extremely challenging, and this is further complicated by the fact that no single model will ever offer the best performance across all datasets (or even within a single dataset). Evaluation of model performance represents a further problem in need of addressing.

The page below provides an introduction to survival analysis with competing risks and comorbidities, as well as an overview of some of our own lab’s key projects that have driven the entire research area forward.

If you’re interested in learning more, you can find our publications on survival analysis, competing risks, and comorbidities, here.


This page is one of several introductions to areas that we see as “research pillars” for our lab. It is a living document, and the content here will evolve as we continue to reach out to the machine learning and healthcare communities, building a shared vision for the future of healthcare.

Our primary means of building this shared vision is through two groups of online engagement sessions: Inspiration Exchange (for machine learning students) and Revolutionizing Healthcare (for the healthcare community). If you would like to get involved, please visit the page below.

This page is authored and maintained by Mihaela van der Schaar and Nick Maxfield.


Survival analysis: individualized predictions supporting bespoke medicine

Survival analysis (often referred to as time-to-event analysis) refers to the study of the duration until one or more events occur. This is vital to a great many predictive tasks across numerous fields of application, including economics, finance, and engineering—and, of course, healthcare.

A long and diverse literature approaches survival analysis by viewing the event of interest as the first event time (or “hitting time”) of an underlying stochastic process, i.e. the first time at which the stochastic process reaches a prescribed boundary. Depending on the context, the first event time may represent the time until a stock option can profitably be exercised, or the time to failure of a mechanical system. In healthcare, the objective is principally to model the probability of a certain event occurring (such as a heart attack) as a function of the covariates.

A fundamental problem of survival analysis in healthcare and beyond is the need to understand the relationships between the (distribution of) event times and the covariates, such as the features of an individual patient. In the medical setting, survival analysis is often applied to the discovery of risk factors affecting survival, comparison among risks of different subjects at a certain time of interest, and decisions related to cost-efficient acquisition of information (e.g. screening for cancer).

The problem of survival analysis with competing risks has gained significant attention in the medical community due to the realization that many chronic diseases possess a shared biology, as we will explain in depth below.

This is a particularly important but complicated aspect of our ongoing drive to make bespoke medicine a reality, fully leveraging the power of machine learning to create a view of the individual that is both holistic (spanning all available medical history and complementary data, and generating comprehensive arrays of risk scores and predictions) and dynamic (factoring in time-series data and evolving over time to provide lifetime care).

Diagnostic and treatment decisions are still too often made based on the “average” patient. Yet the relationship of a particular treatment or condition to survival can vary greatly within populations, and understanding it is important both for guiding clinical decisions and for understanding the disease. By contrast, accurate and informative models can enable clinicians to generate bespoke survival predictions, which in turn can inform the development of bespoke treatment plans for each patient at any point in time.

In the summary below, we will be introducing a range of methods and approaches developed with healthcare in mind, although all of these can be applied more broadly. We will start by sharing some fundamental requirements for successful survival analysis models, and will then introduce the key challenges of building such models in static and temporal settings, before finally discussing metrics for model evaluation.

Survival analysis for healthcare: core requirements

The importance of survival analysis has prompted the development of a variety of approaches to model the survival function.

The Cox proportional hazards model (first introduced in 1972) is still the most widely used model in the medical setting, but it makes strong assumptions about the underlying stochastic process and about the relationship between the covariates and the parameters of that process.
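As a point of reference, the Cox model assumes a hazard of the form h(t | x) = h₀(t) exp(βᵀx), so the hazards of any two patients stay proportional over time; β is estimated by maximizing the partial likelihood, in which the baseline hazard h₀ cancels out. The sketch below computes the negative partial log-likelihood with NumPy, assuming right-censored data and no tied event times; the function name and arguments are ours, for illustration only.

```python
import numpy as np


def cox_neg_partial_log_likelihood(beta, X, times, events):
    """Negative Cox partial log-likelihood, assuming no tied event times.

    beta   : (p,) regression coefficients
    X      : (n, p) covariate matrix
    times  : (n,) observed times (event or censoring)
    events : (n,) 1 if the event was observed, 0 if right-censored
    """
    eta = X @ beta  # linear predictor beta^T x_i for every subject
    nll = 0.0
    for i in range(len(times)):
        if events[i] == 1:
            # Risk set R(t_i): subjects still under observation at time t_i
            at_risk = times >= times[i]
            nll -= eta[i] - np.log(np.sum(np.exp(eta[at_risk])))
    return nll


X = np.array([[0.5], [-1.0], [2.0]])
times = np.array([1.0, 2.0, 3.0])
events = np.array([1, 0, 1])
print(cox_neg_partial_log_likelihood(np.zeros(1), X, times, events))
```

Censored subjects enter only through the risk sets and never as events, which is how the partial likelihood accommodates censoring.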

While machine learning techniques (neural networks in particular) have been used to model non-linear representations for the relation between covariates and the risk of an event, these approaches generally still maintain the basic assumptions of the Cox model, weakening only the assumption of the form of the relationship between covariates and hazard rate.

Parametric and semiparametric approaches construct models that rely on specific assumptions about the true underlying distribution; non-parametric approaches take a more agnostic point of view to construct models that rely on (variants of) familiar machine learning methods. The models produced by these various approaches offer different strengths and weaknesses in terms of both discriminative performance and calibration, and their relative performance varies across different datasets and at different time horizons within a single dataset.

In general, a good survival analysis model should be accurate, trustworthy, and interpretable. More specifically, the usefulness of a model should be assessed both by how well the model discriminates among predicted risks and by how well the model is calibrated.

This is easy to see in the example of organ transplantation. Successful transplantation can mean many additional years of life for recipients, but (given the mismatch between organ supply and demand) it is important to correctly discriminate between, and prioritize, candidates on the basis of risk. However, if the risk predictions of a given model are not well calibrated (i.e. if there is poor agreement between predicted and observed outcomes), the model will have little prognostic value for clinicians.
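Discrimination is commonly measured with a concordance index: among comparable pairs of patients, how often does the patient who experiences the event first receive the higher predicted risk? Below is a minimal pure-Python sketch of Harrell's C-index for right-censored data. It is one common choice among several, it says nothing about calibration (which needs a separate check, for instance comparing predicted and observed event rates), and the function name is ours.

```python
def harrell_c_index(times, events, risks):
    """Harrell's concordance index for right-censored data.

    A pair (i, j) is comparable when i has an observed event and t_i < t_j;
    it is concordant when the model assigns i the higher risk score.
    Ties in predicted risk count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(len(times)):
            if times[i] < times[j]:  # j is known to have outlived i
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


print(harrell_c_index([2, 5, 3, 8], [1, 0, 1, 1], [0.9, 0.2, 0.7, 0.1]))  # 1.0
```

For simplicity, pairs with tied observed times are skipped; a value of 0.5 corresponds to random ranking and 1.0 to perfect ranking.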

Accounting for censoring

The defining characteristic of survival analysis (and what separates it from other areas of data analysis) is that time-to-event is often observed subject to censoring.

Time-to-event (survival) data provides three pieces of information for each subject: i) observed covariates, ii) the observed time s (the time to the event or to censoring), and iii) a label indicating the type of event (e.g., death or an adverse clinical event). This label also encodes right-censoring, the most common form of censoring in survival analysis settings. Right-censoring occurs when the subject is lost to follow-up or the study ends before the event is observed.

Figure: censoring in survival analysis

Simplified depiction of a time-to-event dataset composed of subjects who have been followed for a finite amount of time. The label set 𝒦 = {∅, 1, …, K} contains K mutually exclusive competing events, indexed 1 to K, together with ∅, which corresponds to right-censoring.

Censoring prevents us from observing the event time, but it does tell us that the patient had not experienced the event (and, if the event is death, was still alive) up until the censoring time. Beyond this, we generally assume that censoring is independent. This assumption is common in the survival literature, and implies that whether a subject withdraws from a study depends only on the observed history, rather than on the clinical outcomes.

If it is not handled properly, censoring can introduce bias into analyses of a population whose members are observed for different lengths of follow-up. Importantly, the presence of censoring also means that the goal of survival analysis, unlike that of classification or regression, must be to produce a valid cumulative incidence function (or, equivalently, a survival function) for a given type of event.
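For a single event type, the textbook nonparametric way to turn right-censored data into a valid survival function is the Kaplan-Meier estimator, S(t) = ∏_{t_j ≤ t} (1 − d_j / n_j), where d_j is the number of events at time t_j and n_j is the number of subjects still at risk just before t_j. Censored subjects leave the risk set at their censoring time without counting as events, and the cumulative incidence is then F(t) = 1 − S(t). A minimal pure-Python sketch of this classical baseline (function name ours):

```python
def kaplan_meier(times, events):
    """Kaplan-Meier survival estimate from right-censored data.

    times  : observed times (event or censoring)
    events : 1 if the event was observed, 0 if right-censored
    Returns a list of (t, S(t)) at each distinct event time.
    """
    data = sorted(zip(times, events))
    n_at_risk = len(data)
    surv = 1.0
    curve = []
    i = 0
    while i < len(data):
        t = data[i][0]
        deaths = removed = 0
        # Events and censorings tied at t are all at risk just before t
        while i < len(data) and data[i][0] == t:
            deaths += data[i][1]
            removed += 1
            i += 1
        if deaths > 0:
            surv *= 1.0 - deaths / n_at_risk
            curve.append((t, surv))
        n_at_risk -= removed  # censored subjects leave the risk set here
    return curve


curve = kaplan_meier([1, 2, 2, 3, 4], [1, 0, 1, 1, 0])
print(curve)
print([(t, 1.0 - s) for t, s in curve])  # cumulative incidence F(t) = 1 - S(t)
```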

Figure: cumulative incidence function

All good survival analysis models must account for censoring. This is a feature of all of our lab’s models featured on this page; while the precise method of handling censoring may not be explicitly stated in the summaries provided below, full details are available in the papers themselves.

How can we craft the best possible models for survival analysis?

As is generally the case throughout machine learning, no single model is best across all datasets, and frequently no single model is best across all time horizons within a single dataset. This presents a challenge to familiar methods of model selection or ensemble creation.

An additional challenge is that survival analysis needs to yield good performance at different time horizons while providing a valid and well-calibrated survival function; this makes conventional model selection and ensemble methods inapplicable.

This challenge prompted our lab to take an ensemble approach to survival analysis, making use of our extensive work in the area of automated machine learning (AutoML).

Our novel approach, known as SurvivalQuilts (introduced in a paper for AISTATS 2019), uses automated machine learning to combine the collective intelligence of different underlying survival models to produce a valid survival function that is well-calibrated and offers superior discriminative performance at different time horizons. SurvivalQuilts pieces together existing survival analysis models according to endogenously determined, time-varying weights (for more information, please see the SurvivalQuilts paper linked to below). We refer to this kind of construction as temporal quilting, and to the resultant model as a survival quilt.

Figure: temporal quilting

An example of temporal quilting with prescribed weights for a range of survival models (COX, RISF, CISF) at t1, t2, and t3. A risk function is constructed by stitching together the weighted increment functions of each survival model between two adjacent time horizons.

The core of our method is an algorithm for configuring the weights sequentially over a (possibly very fine) grid of time intervals. To render the problem tractable, we apply constrained Bayesian optimization (BO), which models the discrimination and calibration performance metrics as black-box functions whose input is an array of weights (over different time horizons) and whose output is the corresponding performance achieved. Based on the constructed array of weights, our method builds a single predictive model (a survival quilt) that provides a valid risk function.
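The stitching described in the figure caption above can be sketched in NumPy. This is an illustration under our own assumptions, not the SurvivalQuilts implementation: the base models' risk (cumulative incidence) curves are assumed to be precomputed on a shared time grid, interval boundaries are given as grid indices, and the per-interval weights are supplied directly rather than found by constrained Bayesian optimization. Within each interval, the quilt adds the weighted increments of the base models, so non-negative weights keep the result non-decreasing.

```python
import numpy as np


def quilt_risk(base_risks, boundaries, weights):
    """Stitch base-model risk curves together with interval-specific weights.

    base_risks : (M, G) array; row m is model m's risk F_m on a common time grid
    boundaries : increasing grid indices [b_1, ..., b_K] with b_K == G - 1;
                 interval k covers grid indices (b_{k-1}, b_k], with b_0 = 0
    weights    : (K, M) array of non-negative weights, one row per interval
    """
    base_risks = np.asarray(base_risks, dtype=float)
    weights = np.asarray(weights, dtype=float)
    quilt = np.empty(base_risks.shape[1])
    quilt[0] = weights[0] @ base_risks[:, 0]  # starting value at the first grid point
    start = 0
    for k, end in enumerate(boundaries):
        for g in range(start + 1, end + 1):
            # Weighted increment of every base model over one grid step
            increment = base_risks[:, g] - base_risks[:, g - 1]
            quilt[g] = quilt[g - 1] + weights[k] @ increment
        start = end
    # Simplification of ours: cap the risk at 1 so it remains a valid probability
    return np.minimum(quilt, 1.0)


base = [[0.0, 0.1, 0.2, 0.3],   # hypothetical model A
        [0.0, 0.2, 0.4, 0.6]]   # hypothetical model B
print(quilt_risk(base, boundaries=[1, 3], weights=[[1.0, 0.0], [0.0, 1.0]]))
```

Here the quilt follows model A up to the first boundary and then accumulates model B's increments, so it changes behavior across horizons while remaining a single monotone curve.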

Temporal quilting for survival analysis

Changhee Lee, William Zame, Ahmed Alaa, Mihaela van der Schaar

AISTATS 2019

The importance of survival analysis in many disciplines (especially in medicine) has led to the development of a variety of approaches to modeling the survival function. Models constructed via various approaches offer different strengths and weaknesses in terms of discriminative performance and calibration, but no one model is best across all datasets or even across all time horizons within a single dataset.

Because we require both good calibration and good discriminative performance over different time horizons, conventional model selection and ensemble approaches are not applicable.

This paper develops a novel approach that combines the collective intelligence of different underlying survival models to produce a valid survival function that is well-calibrated and offers superior discriminative performance at different time horizons. Empirical results show that our approach provides significant gains over the benchmarks on a variety of real-world datasets.

The following introduction to SurvivalQuilts is from Mihaela van der Schaar’s keynote at the ICML 2020 AutoML workshop.

SurvivalQuilts was recently used to develop a prognostic model for prostate cancer, and showed an incremental net benefit by comparison with other models currently used for prostate cancer. Details are available in an article published in The Lancet Digital Health.

Competing risks

In the static setting (as opposed to the dynamic setting, introduced below), survival analysis methods produce predictions over time for an individual based on data available at a single moment—for example, a “snapshot” of information about a patient at the point of first diagnosis or initial admission to hospital.

While static prediction may sound less complicated than dynamic prediction, challenges still abound. In particular, designing optimal treatment plans for elderly patients or patients with comorbidities is an inherently challenging problem, whether in the static or dynamic setting: the nature (and the appropriate level of invasiveness) of the best therapeutic intervention for a patient with a specific clinical risk depends on whether this patient suffers from, or is susceptible to, other “competing risks.” For example, encounters with competing risks occur frequently in oncology and cardiovascular medicine, where the risk of a cardiac disease may determine whether a cancer patient should undergo chemotherapy or a particular type of surgery.

An example of a survival analysis model’s combined risk profile for multiple adverse outcomes (onset of cardiovascular disease and breast cancer) over time.

Conventional methods for survival analysis, such as the Kaplan-Meier method and standard Cox proportional hazards regression, are not equipped to handle competing risks. While alternate variants of those methods that rely on cumulative incidence estimators have been proposed and used in clinical research, machine learning models (powered by high-quality electronic health records and complementary datasets) can provide predictions that offer superior accuracy and quality of insight.
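The cumulative incidence estimators mentioned above account for the fact that a patient who experiences one event can no longer experience the others. A standard nonparametric example is the Aalen-Johansen estimator, F_k(t) = Σ_{t_j ≤ t} S(t_j−) · d_kj / n_j, where S(t_j−) is the all-cause Kaplan-Meier survival just before t_j and d_kj counts events of cause k at t_j. Computing 1 − Kaplan-Meier while treating competing events as censored tends to overstate the risk. A minimal pure-Python sketch, using label 0 for censoring and 1, …, K for the competing causes (function name ours):

```python
def cumulative_incidence(times, labels, cause):
    """Aalen-Johansen estimate of the cumulative incidence of `cause`.

    times  : observed times (event or censoring)
    labels : 0 for right-censoring, 1..K for the competing event types
    Returns a list of (t, F_cause(t)) at each distinct event time (any cause).
    """
    data = sorted(zip(times, labels))
    n_at_risk = len(data)
    surv = 1.0  # all-cause Kaplan-Meier survival just before the current time
    cif = 0.0
    curve = []
    i = 0
    while i < len(data):
        t = data[i][0]
        d_any = d_cause = removed = 0
        # Count every event and censoring that happens at time t
        while i < len(data) and data[i][0] == t:
            label = data[i][1]
            if label != 0:
                d_any += 1
                if label == cause:
                    d_cause += 1
            removed += 1
            i += 1
        if d_any > 0:
            # Probability of reaching t event-free times the cause-specific hazard
            cif += surv * d_cause / n_at_risk
            surv *= 1.0 - d_any / n_at_risk  # update all-cause survival
            curve.append((t, cif))
        n_at_risk -= removed
    return curve


print(cumulative_incidence([1, 2, 3, 4], [1, 2, 0, 1], cause=1))
print(cumulative_incidence([1, 2, 3, 4], [1, 2, 0, 1], cause=2))
```

Summed over causes, the cumulative incidence functions equal 1 minus the all-cause Kaplan-Meier survival, so they remain valid probabilities.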

One example of such a machine learning model is a nonparametric Bayesian model for survival analysis with competing risks based on deep multi-task Gaussian processes (DMGPs), first introduced in a paper published at NeurIPS 2017.

Our DMGP model relies on a novel conception of the competing risks problem as a multi-task learning problem; that is, we model the cause-specific survival times as the outputs of a random vector-valued function, the inputs to which are the patients’ covariates. This allows us to learn a “shared representation” of the patients’ survival times with respect to multiple related comorbidities.

Unlike many existing parametric survival models, our DMGP model neither assumes a parametric form for the interactions between the covariates and the survival times, nor does it restrict the distribution of the survival times to a parametric model.

Figure: survival analysis with deep multi-task Gaussian processes

Survival analysis with two competing risks using DMGP. The posterior distribution of T given D is displayed in the top-left panel, and the corresponding cumulative incidence functions for a particular patient with covariates X* are displayed in the bottom-left panel.

The posterior distributions of the two DMGP layers conditional on their inputs are depicted in the right panels.

Deep multi-task Gaussian processes for survival analysis with competing risks

Ahmed Alaa, Mihaela van der Schaar

NeurIPS 2017

Designing optimal treatment plans for patients with comorbidities requires accurate cause-specific mortality prognosis.

Motivated by the recent availability of linked electronic health records, we develop a nonparametric Bayesian model for survival analysis with competing risks, which can be used for jointly assessing a patient’s risk of multiple (competing) adverse outcomes.

The model views a patient’s survival times with respect to the competing risks as the outputs of a deep multi-task Gaussian process (DMGP), the inputs to which are the patients’ covariates. Unlike parametric survival analysis methods based on Cox and Weibull models, our model uses DMGPs to capture complex non-linear interactions between the patients’ covariates and cause-specific survival times, thereby learning flexible patient-specific and cause-specific survival curves, all in a data-driven fashion without explicit parametric assumptions on the hazard rates. We propose a variational inference algorithm that is capable of learning the model parameters from time-to-event data while handling right censoring.

Experiments on synthetic and real data show that our model outperforms the state-of-the-art survival models.

In 2018, our lab introduced DeepHit, a somewhat different approach to survival analysis that uses a deep neural network to learn the distribution of survival times directly.

DeepHit employs a network architecture that consists of a single shared sub-network and a family of cause-specific sub-networks. The network is trained using a loss function that exploits both survival times and relative risks. DeepHit makes no assumptions about the form of the underlying stochastic process; it therefore allows for the possibility that, even for a fixed cause or causes (e.g. a disease or diseases), both the parameters and the form of the stochastic process depend on the covariates.
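To make the two-part loss concrete, here is a NumPy sketch of a DeepHit-style loss over a discretized time axis. The network output is treated as a joint probability mass function over (cause, time bin). The first term is a log-likelihood that uses the observed cause and time for uncensored patients, and the probability of no event by the censoring time for censored ones. The second is an exponential ranking penalty that compares cumulative incidences between comparable pairs. This follows our reading of the loss described in the DeepHit paper; the single ranking weight alpha, the scale sigma, the function name and the toy inputs are illustrative assumptions, and the neural network itself is omitted.

```python
import numpy as np


def deephit_style_loss(probs, times, causes, alpha=1.0, sigma=0.1, eps=1e-8):
    """
    probs  : (N, K, T) predicted joint pmf P(time bin s, cause k | x_i); in
             practice a softmax over all K * T outputs, so each probs[i] sums to 1
    times  : (N,) integer time bins s_i (event or censoring time)
    causes : (N,) cause labels k_i in {1, ..., K}, or 0 if right-censored
    """
    N, K, _ = probs.shape
    cif = np.cumsum(probs, axis=2)  # F_k(s | x_i) for every patient, cause and bin
    idx = np.arange(N)
    observed = causes > 0

    # Log-likelihood: P(s_i, k_i | x_i) if observed, else 1 - sum_k F_k(s_i | x_i)
    loglik = np.log(probs[idx[observed], causes[observed] - 1, times[observed]] + eps).sum()
    no_event = 1.0 - cif[idx[~observed], :, times[~observed]].sum(axis=1)
    loglik += np.log(no_event + eps).sum()

    # Ranking: patient i (event k at s_i) should have a higher F_k(s_i) than
    # any patient j who was still event-free after s_i
    ranking = 0.0
    for k in range(1, K + 1):
        for i in np.where(causes == k)[0]:
            later = times > times[i]
            diff = cif[i, k - 1, times[i]] - cif[later, k - 1, times[i]]
            ranking += alpha * np.exp(-diff / sigma).sum()

    return -loglik + ranking


probs = np.array([[[0.7, 0.2, 0.1]],   # patient 0: early event predicted
                  [[0.1, 0.2, 0.7]]])  # patient 1: late event predicted
print(deephit_style_loss(probs, times=np.array([0, 1]), causes=np.array([1, 0])))
```

The ranking term is what lets the loss reward correct ordering of risks between patients, beyond what the per-patient likelihood captures.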

While other approaches based on neural networks have maintained the basic assumptions of the Cox model, weakening only the assumption of the form of the relationship between covariates and the hazard rate, DeepHit allows for the possibility that the relationship between covariates and risk(s) changes over time.

Furthermore, DeepHit (like DMGP) handles both single risks and competing risks.

DeepHit: a deep learning approach to survival analysis with competing risks

Changhee Lee, William Zame, Jinsung Yoon, Mihaela van der Schaar

AAAI 2018

Survival analysis (time-to-event analysis) is widely used in economics and finance, engineering, medicine and many other areas. A fundamental problem is to understand the relationship between the covariates and the (distribution of) survival times (times-to-event).

Much of the previous work has approached the problem by viewing the survival time as the first hitting time of a stochastic process, assuming a specific form for the underlying stochastic process, using available data to learn the relationship between the covariates and the parameters of the model, and then deducing the relationship between covariates and the distribution of first hitting times (the risk). However, previous models rely on strong parametric assumptions that are often violated.

This paper proposes a very different approach to survival analysis, DeepHit, that uses a deep neural network to learn the distribution of survival times directly.

DeepHit makes no assumptions about the underlying stochastic process and allows for the possibility that the relationship between covariates and risk(s) changes over time. Most importantly, DeepHit smoothly handles competing risks; i.e. settings in which there is more than one possible event of interest.

Comparisons with previous models on the basis of real and synthetic datasets demonstrate that DeepHit achieves large and statistically significant performance improvements over previous state-of-the-art methods.

In another model introduced in 2018, we took a different approach, conceptualizing the competing risks problem as a mixture of competing survival trajectories with latent variables determining the weight of these different but related trajectories.

In this case, the parameters of the cause-specific distributions and the assignment variables were modeled jointly with multivariate random forests (MRFs), allowing us to learn a “shared representation” of patient survival times with respect to multiple related comorbidities and to capture nonlinear covariate influences.

Naturally, such a process gives rise to a patient-specific survival distribution, from which a patient-specific, cause-specific cumulative incidence function can easily be derived.
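For a mixture of this kind, the cause-specific cumulative incidence follows directly: F_k(t | x) = π_k(x) · G_k(t | x), where π_k(x) is the patient's probability of being assigned to trajectory k and G_k is the distribution function of that trajectory's survival time. The sketch below assumes Weibull trajectories and already-estimated mixture weights purely for illustration; in the model described above both would come from the multivariate random forests, which are not reproduced here.

```python
import math


def mixture_cif(t, mix_weights, shapes, scales):
    """Cause-specific cumulative incidence F_k(t) = pi_k * G_k(t).

    mix_weights    : probabilities pi_k of each event type (summing to 1)
    shapes, scales : Weibull parameters of each cause-specific trajectory
                     (an illustrative assumption, not the paper's parametrization)
    """
    cifs = []
    for pi, shape, scale in zip(mix_weights, shapes, scales):
        weibull_cdf = 1.0 - math.exp(-((t / scale) ** shape))
        cifs.append(pi * weibull_cdf)
    return cifs


# Hypothetical patient: 30% assigned to cause 1, 70% to cause 2
print(mixture_cif(2.0, mix_weights=[0.3, 0.7], shapes=[1.5, 1.0], scales=[4.0, 10.0]))
```

As t grows, each F_k approaches π_k, so the cause-specific incidences together never exceed 1.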

Tree-based Bayesian mixture model for competing risks

Alexis Bellot, Mihaela van der Schaar

AISTATS 2018

Many chronic diseases possess a shared biology. Therapies designed for patients at risk of multiple diseases need to account for the shared impact they may have on related diseases to ensure maximum overall well-being. Learning from data in this setting differs from classical survival analysis methods since the incidence of an event of interest may be obscured by other related competing events.

We develop a semi-parametric Bayesian regression model for survival analysis with competing risks, which can be used for jointly assessing a patient’s risk of multiple (competing) adverse outcomes. We construct a Hierarchical Bayesian Mixture (HBM) model to describe survival paths in which a patient’s covariates influence both the estimation of the type of adverse event and the subsequent survival trajectory through Multivariate Random Forests. In addition, variable importance measures, which are essential for clinical interpretability, are induced naturally by our model. With this setting, we aim to provide not only accurate individual estimates but also interpretable conclusions for use as a clinical decision support tool.

We compare our method with various state-of-the-art benchmarks on both synthetic and clinical data.

The last major project we will highlight related to the static setting is multitask boosting, an approach introduced in a paper published at NeurIPS 2018.

With the aim of developing a flexible simultaneous description of the likelihood of different events over time by estimating full probability distributions, we specifically sought to leverage the heterogeneity present in large modern datasets, the complexity in underlying relationships between events/tasks, and the strong imbalance often observed between events/tasks.

To do this, we developed a boosting algorithm in which each task-specific time-to-event distribution is a component of a multi-output function. A distinctive feature is that each weak estimator (whose performance is sub-optimal) learns a shared representation between tasks by recursively partitioning the observed data (analogous to the construction of trees) from all related tasks using a measure of similarity between instances that involves all related tasks.

This means that we learn a shared representation directly by selecting appropriate subsets of patients who may experience different events but share a common time-to-event trajectory.

Our algorithm is general and represents the first boosting-like method for time-to-event data with multiple outcomes.

Figure: multitask boosting

In the figure above, each patient is characterized by their body mass index (BMI), cholesterol level and age at menarche. The medical fact is that increased BMI increases the risk of both breast cancer and CVD; increased cholesterol increases the risk of CVD but is irrelevant for breast cancer; increased age at menarche decreases the risk of breast cancer but is irrelevant for CVD. (Note that the same patients are represented in all panels – the vertical position remains the same while their horizontal position changes due to different features being considered).

The panels show three iterations of boosting using a stump as a weak predictor; the best partition of the data in each case is shown with the yellow threshold.

The first stump recognizes BMI as best separating event times (on average), but mispredicts the survival of patient (a), who has high survival despite having high BMI, and of patient (b), for whom the contrary is true.

Iteration 2, encouraged by the higher weight of (a), considers a split along the cholesterol level and is able to better describe (a)’s survival (high survival due to a low cholesterol level).

Iteration 3, after repeatedly mispredicting (b) in iterations 1 and 2, splits based on age at menarche which explains (b)’s low survival.
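The iterations above rely on weak learners fitted to reweighted patients. A minimal NumPy sketch of one such weak learner follows: a weighted decision stump that pools patients regardless of which event they experience and picks the feature and threshold that minimize the weighted squared error of their event times. This is a toy simplification of ours; the paper's splitting criterion uses a multi-task similarity measure and full time-to-event distributions, which are not reproduced here. In a boosting loop, patients whose survival is mispredicted, like (a) and (b), would have their weights w increased before the next stump is fitted.

```python
import numpy as np


def fit_weighted_stump(X, y, w):
    """Fit a one-split regression stump by weighted least squares.

    X : (n, p) features (e.g. BMI, cholesterol, age at menarche)
    y : (n,) event times, pooled across event types
    w : (n,) positive patient weights from previous boosting rounds
    Returns (feature, threshold, left_mean, right_mean).
    """
    best = None
    for f in range(X.shape[1]):
        for thr in np.unique(X[:, f]):
            left = X[:, f] <= thr
            right = ~left
            if not right.any():
                continue  # a split must leave patients on both sides
            left_mean = np.average(y[left], weights=w[left])
            right_mean = np.average(y[right], weights=w[right])
            sse = (np.sum(w[left] * (y[left] - left_mean) ** 2)
                   + np.sum(w[right] * (y[right] - right_mean) ** 2))
            if best is None or sse < best[0]:
                best = (sse, f, thr, left_mean, right_mean)
    return best[1:]


# Hypothetical patients: columns are (BMI score, cholesterol score)
X = np.array([[1.0, 5.0], [2.0, 3.0], [3.0, 5.0], [4.0, 3.0]])
y = np.array([10.0, 10.0, 2.0, 2.0])
print(fit_weighted_stump(X, y, np.ones(4)))
```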

Multitask boosting for survival analysis with competing risks

Alexis Bellot, Mihaela van der Schaar

NeurIPS 2018

The co-occurrence of multiple diseases among the general population is an important problem, as those patients are at greater risk of complications and represent a large share of health care expenditure.

Learning to predict time-to-event probabilities for these patients is a challenging problem because the risks of events are correlated (there are competing risks) with often only few patients experiencing individual events of interest, and of those only a fraction are actually observed in the data.

We introduce in this paper a survival model with the flexibility to leverage a common representation of related events that is designed to correct for the strong imbalance in observed outcomes. The procedure is sequential: outcome-specific survival distributions form the components of nonparametric multivariate estimators which we combine into an ensemble in such a way as to ensure accurate predictions on all outcome types simultaneously.

Our algorithm is general and represents the first boosting-like method for time-to-event data with multiple outcomes. We demonstrate the performance of our algorithm on synthetic and real data.

Dynamic prediction of competing risks, comorbidities, and multiple events

As described above, many illnesses arise not just from individual causes for a specific disease, but as a complex interaction between other diseases or conditions that a patient may already have had. For example, long-term diabetes increases the risk of cardiovascular and renal disease, making high blood pressure and its complications (such as heart attacks) more likely.

Comorbid diseases co-occur and progress via complex temporal patterns that vary among individuals. Accordingly, identifying and understanding the contribution of comorbidities to disease progression and outcomes is fundamental to medicine and clinical practice. This is why predictions in the dynamic setting, which factor in longitudinal measurements spanning the entire medical history of a patient, are particularly important.

One key limitation of existing survival models is that they utilize only a small fraction of the available longitudinal (repeated) measurements of biomarkers and other risk factors. In particular, even though biomarkers and other risk factors are measured repeatedly over time, survival analysis is typically based on the last available measurement. This represents a severe limitation, since the evolution of biomarkers and risk factors has been shown to be informative in predicting the onset of disease and various risks.

While attempts have been made to overcome these shortcomings with approaches such as landmarking and joint models, these approaches have their own limitations. Landmarking is “partially conditional”: each survival model is conditioned only on the information accrued by a pre-specified landmark time, rather than on the entire longitudinal history, and survival probabilities are typically predicted using the last measurements as estimates of the biomarkers at these landmark times. Joint models, meanwhile, are vulnerable to model misspecification (e.g. of the assumed longitudinal process, or of the proportional hazards assumption for the time-to-event process), which limits overall performance, and optimizing the joint likelihood poses severe computational challenges for high-dimensional datasets.
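To make the landmarking procedure concrete, the pure-Python sketch below builds a landmark dataset. At a landmark time it keeps only patients who are still event-free and under follow-up, summarizes each patient's history by the last measurement at or before the landmark (last observation carried forward), and administratively censors follow-up at a prediction horizon. The record layout and function name are hypothetical; a survival model such as a Cox model would then be fitted to the resulting rows. Every measurement before the last one is discarded, which is exactly the limitation described above.

```python
def landmark_dataset(subjects, landmark, horizon):
    """Build (last_value, time_since_landmark, event) rows at a landmark time.

    subjects : list of dicts (hypothetical layout) with keys
        'obs'   : list of (time, value) biomarker measurements sorted by time
        'time'  : event or censoring time
        'event' : 1 if the event was observed, 0 if censored
    """
    rows = []
    for s in subjects:
        if s['time'] <= landmark:
            continue  # no longer at risk at the landmark time
        past = [value for t, value in s['obs'] if t <= landmark]
        if not past:
            continue  # no measurement available yet
        last_value = past[-1]  # last observation carried forward
        if s['time'] > landmark + horizon:
            rows.append((last_value, horizon, 0))  # administratively censored
        else:
            rows.append((last_value, s['time'] - landmark, s['event']))
    return rows


subjects = [
    {'obs': [(0, 1.0), (2, 1.5), (5, 3.0)], 'time': 6, 'event': 1},
    {'obs': [(0, 2.0)], 'time': 1, 'event': 1},
    {'obs': [(0, 0.5), (1, 0.7)], 'time': 10, 'event': 0},
]
print(landmark_dataset(subjects, landmark=3, horizon=5))
```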

Shifting from the static setting to the dynamic setting

To provide a better understanding of disease progression, it is essential to incorporate longitudinal measurements of biomarkers and risk factors into a model. Rather than discarding valuable information recorded over time, this allows us to make better risk assessments on the clinical events. This is why our lab developed Dynamic-DeepHit, an extension of the work that went into the DeepHit model introduced above.

While inheriting the neural network structure of its predecessor and maintaining the ability to handle competing risks, Dynamic-DeepHit learns, on the basis of the available longitudinal measurements, a data-driven distribution of first event times of competing events. This completely removes the need for explicit model specifications (i.e., no assumptions about the form of the underlying stochastic processes are made) and enables us to learn the complex relationships between trajectories and survival probabilities.

To enable dynamic survival analysis with longitudinal time-to-event data, Dynamic-DeepHit employs a shared subnetwork and a family of cause-specific subnetworks. The shared subnetwork encodes the information in longitudinal measurements into a fixed-length vector (i.e., a context vector) using a recurrent neural network (RNN).

A temporal attention mechanism is applied to the hidden states of the RNN when constructing the context vector. This allows Dynamic-DeepHit to access the relevant information accumulated along the trajectory of past longitudinal measurements by paying attention to relevant hidden states across different time stamps. The cause-specific subnetworks then take the context vector and the last measurements as input and estimate the joint distribution of the first event time and competing events, which is used for risk prediction.
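The attention step can be illustrated with a small NumPy sketch. For illustration only, it assumes a dot-product score between each hidden state and a query vector summarizing the latest measurement; in Dynamic-DeepHit the scores come from a learned network, and the names here are hypothetical. A softmax turns the scores into weights that sum to one, and the context vector is the weighted sum of the hidden states.

```python
import numpy as np


def attention_context(hidden_states, query):
    """Attention pooling over RNN hidden states (dot-product scores assumed).

    hidden_states : (J, d) hidden states h_1..h_J for the past measurements
    query         : (d,) hypothetical summary of the latest measurement
    Returns (context, weights), with weights summing to 1.
    """
    scores = hidden_states @ query         # one relevance score per time stamp
    scores = scores - scores.max()         # shift for a numerically stable softmax
    weights = np.exp(scores) / np.exp(scores).sum()
    context = weights @ hidden_states      # weighted sum of hidden states
    return context, weights


h = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
context, weights = attention_context(h, query=np.array([2.0, 0.0]))
print(weights, context)
```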

Figure: Dynamic-DeepHit cumulative incidence functions

As shown above, Dynamic-DeepHit updates its survival predictions (presented as cumulative incidence functions) as new observations are collected over time.

Gray solid lines, yellow dotted lines, and stars indicate the times at which measurements are taken, the time at which the patient is censored, and the time at which an event occurred, respectively.
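One way to express such dynamic updates is to condition the predicted joint distribution on the patient having remained event-free up to the time t of the latest measurement: R_k(τ | t) = P(t < T ≤ t + τ, cause k | T > t). On a discrete time axis, this is the predicted mass of cause k in the window (t, t + τ], divided by the mass not yet used up by time t. A minimal NumPy sketch under these assumptions, with the distribution supplied directly rather than produced by a network (function name ours):

```python
import numpy as np


def conditional_risk(pmf, t, tau):
    """Risk of each cause within (t, t + tau], given no event up to bin t.

    pmf : (K, T) predicted joint pmf over (cause, discrete time bin)
    t   : index of the time bin of the latest measurement
    tau : prediction window length, in bins
    """
    event_free = 1.0 - pmf[:, : t + 1].sum()          # P(T > t)
    window = pmf[:, t + 1 : t + tau + 1].sum(axis=1)  # P(t < T <= t + tau, cause k)
    return window / event_free


pmf = np.array([[0.1, 0.1, 0.2, 0.1],   # cause 1
                [0.1, 0.2, 0.1, 0.1]])  # cause 2
print(conditional_risk(pmf, t=0, tau=2))
```

Each time a new measurement arrives, t moves forward and the risks are renormalized over the remaining horizon.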

Dynamic-DeepHit: a deep learning approach for dynamic survival analysis with competing risks based on longitudinal data

Changhee Lee, Jinsung Yoon, Mihaela van der Schaar

IEEE Transactions on Biomedical Engineering, 2019

Currently available risk prediction methods are limited in their ability to deal with complex, heterogeneous, and longitudinal data such as that available in primary care records, or in their ability to deal with multiple competing risks.

This paper develops a novel deep learning approach that is able to successfully address current limitations of standard statistical approaches such as landmarking and joint modeling. Our approach, which we call Dynamic-DeepHit, flexibly incorporates the available longitudinal data comprising various repeated measurements (rather than only the last available measurements) in order to issue dynamically updated survival predictions for one or multiple competing risk(s).

Dynamic-DeepHit learns the time-to-event distributions without the need to make any assumptions about the underlying stochastic models for the longitudinal and the time-to-event processes. Thus, unlike existing works in statistics, our method is able to learn data-driven associations between the longitudinal data and the various associated risks without underlying model specifications.

We demonstrate the power of our approach by applying it to a real-world longitudinal dataset from the U.K. Cystic Fibrosis Registry, which includes a heterogeneous cohort of 5883 adult patients with annual follow-ups between 2009 and 2015. The results show that Dynamic-DeepHit provides a drastic improvement in discriminating individual risks of different forms of failures due to cystic fibrosis.

Furthermore, our analysis utilizes post-processing statistics that provide clinical insight by measuring the influence of each covariate on risk predictions and the temporal importance of longitudinal measurements, thereby enabling us to identify covariates that are influential for different competing risks.

Integrating interpretable stages into patient trajectories

Chronic diseases such as cardiovascular disease, cancer, and diabetes progress slowly throughout a patient’s lifetime. This progression can be segmented into “stages” that manifest through clinical observations.

Almost all existing models of disease progression are based on variants of the hidden Markov model (HMM). Disease dynamics in such models are easily interpretable, as they are fully summarized by a single matrix of transition probabilities between disease states. Markovian dynamics also simplify inference, because the model likelihood factorizes in a way that makes efficient forward and backward message passing possible.
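The tractability mentioned above comes from the forward recursion. Because the next state depends only on the current one, filtering needs just one matrix-vector product per time step. Below is a minimal NumPy sketch of the scaled forward algorithm. It uses a hypothetical two-state model whose emission likelihoods are assumed to be precomputed.

```python
import numpy as np

def hmm_forward(pi, A, emission_lik):
    """pi: (S,) initial state probabilities.
    A: (S, S) transition matrix, A[i, j] = P(z_t = j | z_{t-1} = i).
    emission_lik: (T, S) likelihoods p(x_t | z_t = s).
    Returns filtered state probabilities (T, S) and the log-likelihood of the sequence."""
    T, S = emission_lik.shape
    filtered = np.empty((T, S))
    log_lik = 0.0
    alpha = pi * emission_lik[0]
    for t in range(T):
        if t > 0:
            # Markov step: only the previous state's distribution is needed
            alpha = (alpha @ A) * emission_lik[t]
        norm = alpha.sum()            # p(x_t | x_1..x_{t-1})
        log_lik += np.log(norm)
        alpha = alpha / norm          # rescale to avoid underflow
        filtered[t] = alpha
    return filtered, log_lik

# Hypothetical two-stage disease model observed over three visits
pi = np.array([0.6, 0.4])
A = np.array([[0.9, 0.1],
              [0.2, 0.8]])
E = np.array([[0.5, 0.1],
              [0.4, 0.3],
              [0.1, 0.7]])
filtered, log_lik = hmm_forward(pi, A, E)
```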

However, memoryless Markov models assume that a patient’s current state separates their future trajectory from their clinical history. This renders HMM-based models incapable of properly explaining the heterogeneity in patients’ progression trajectories, which often results from their varying clinical histories or the chronologies (timing and order) of their clinical events. This is a crucial limitation in survival analysis models for complex chronic diseases that are accompanied by multiple morbidities.

One of our lab’s approaches aiming to overcome such limitations is attentive state-space modeling (ASSM), first introduced in a paper published at NeurIPS 2019. ASSM was developed to learn accurate and interpretable structured representations for disease trajectories, and offers a deep probabilistic model of disease progression that capitalizes on both the interpretable structured representations of probabilistic models and the predictive strength of deep learning methods.

Unlike conventional Markovian state-space models, ASSM uses recurrent neural networks (RNNs) to capture more complex state dynamics. Since it learns hidden disease states from observational data in an unsupervised fashion, ASSM is well-suited to EHR data, where a patient’s record is seldom annotated with “labels” indicating their true health state.

As implied by the name, ASSM captures state dynamics through an attention mechanism, which observes the patient’s clinical history and maps it to attention weights that determine how much influence previous disease states have on future state transitions. In that sense, attention weights generated for an individual patient explain the causative and associative relationships between the hidden disease states and the past clinical events for that patient. Since attentive state-space models combine the interpretational benefits of probabilistic models and the predictive strength of deep learning, we envision them being used for large-scale disease phenotyping and clinical decision-making.
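The sketch below shows one simple way to picture attention-driven transitions, based on our reading of the description above. A memoryless model uses only the transition-matrix row for the current state. Here, instead, the next-state distribution is a mixture of the rows for all past states, weighted by attention. In ASSM the weights are produced by an RNN over the patient's history. In this hypothetical sketch they are supplied by hand.

```python
import numpy as np

def attentive_transition(past_states, attention, A):
    """past_states: indices of past states z_1..z_{t-1}.
    attention: (t-1,) nonnegative weights summing to 1 (given here; learned in ASSM).
    A: (S, S) transition matrix.
    Returns P(z_t | z_1..z_{t-1}) as an attention-weighted mixture of rows of A."""
    rows = A[np.asarray(past_states)]     # (t-1, S): transition row for each past state
    return attention @ rows

# Hypothetical 3-stage transition matrix
A = np.array([[0.8, 0.15, 0.05],
              [0.1, 0.7, 0.2],
              [0.0, 0.1, 0.9]])
past = [0, 1, 1]
# All weight on the most recent state recovers the ordinary memoryless Markov step
markov = attentive_transition(past, np.array([0.0, 0.0, 1.0]), A)
# Spreading weight over the history lets earlier states shape the next transition
memoryful = attentive_transition(past, np.array([0.5, 0.25, 0.25]), A)
```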

Survival analysis - attentive state-space model

The figure above demonstrates modeling of cystic fibrosis progression trajectories using ASSM. The model learns progression stages (stages 1, 2 and 3) in an unsupervised fashion, with each learned progression stage corresponding to a clinically distinguishable phenotype of disease activity.

On the left-hand side of the figure, the estimated mean of the FEV1 biomarker (currently used by clinicians as a proxy measure of a patient’s health) is plotted for each of ASSM’s learned stages 1, 2 and 3. The stages are clearly separated, which means that the learned progression stages can be translated into actionable information for clinical decision-making.

To illustrate these phenotypes, the right-hand side of the figure plots the risks of various comorbidities (diabetes, asthma, ABPA, hypertension and depression) for patients in the 3 progression stages learned by the model. As we can see, the incidences of those comorbidities and infections increase significantly in the more severe progression stages 2 and 3, as compared to stage 1.

ASSM also features a structured inference network trained to predict posterior state distributions by mimicking the attentive structure of our model. The inference network shares attention weights with the generative model, and uses those weights to create summary statistics needed for posterior state inference.

To the best of our knowledge, ASSM is the first deep probabilistic model that provides clinically meaningful latent representations, with non-Markovian state dynamics that can be made arbitrarily complex while remaining interpretable.

Attentive state-space modeling of disease progression

Ahmed Alaa, Mihaela van der Schaar

NeurIPS 2019

Models of disease progression are instrumental for predicting patient outcomes and understanding disease dynamics. Existing models provide the patient with pragmatic (supervised) predictions of risk, but do not provide the clinician with intelligible (unsupervised) representations of disease pathophysiology.

In this paper, we develop the attentive state-space model, a deep probabilistic model that learns accurate and interpretable structured representations for disease trajectories. Unlike Markovian state-space models, in which the dynamics are memoryless, our model uses an attention mechanism to create “memoryful” dynamics, whereby attention weights determine the dependence of future disease states on past medical history.

To learn the model parameters from medical records, we develop an inference algorithm that simultaneously learns a compiled inference network and the model parameters, leveraging the attentive state-space representation to construct a “Rao-Blackwellized” variational approximation of the posterior state distribution.

Experiments on data from the UK Cystic Fibrosis registry show that our model demonstrates superior predictive accuracy and provides insights into the progression of chronic disease.

Providing uncertainty estimates

As mentioned at the top of this page, a good model must be accurate, trustworthy, and interpretable. While our exploration of the methods presented above has covered accuracy and interpretability so far, one of our lab’s projects has achieved particular advances in the area of uncertainty quantification—a vital aspect of trustworthiness. This is particularly important since many models do not account for the uncertainty of model predictions; for example, RNNs typically only produce single point forecasts.

A substantial portion of machine learning research now investigates prognosis with time-series data, typically focusing on hospitalized patients, for whom information is densely collected. Predictions in this setting tend to target a binary event and require structured, regularly sampled data. Survival data differs from this setting because it is often censored: a patient may drop out of a study, but information up to that time is known. This requires specialized methods, since event labels will not be available for every patient. Moreover, sparsely recorded information tends to lead to substantial uncertainty about model predictions, which needs to be captured for reliable inference in practice.

In a paper published in ACM Transactions on Computing for Healthcare in 2020, our lab discussed these issues and presented a Bayesian nonparametric dynamic survival (BNDS) model that aims to overcome them by (1) quantifying the uncertainty around model predictions in a principled manner at the individual level while (2) avoiding assumptions about the data generating process and adapting model complexity to the structure of the data. Both contributions are particularly important for personalizing healthcare decisions, as predictions may be uncertain due to lack of data, and the underlying heterogeneous physiological time series can be expected to vary wildly across patients.

The BNDS model is generative, defined by a prior distribution on the event probability and on associations between variables and time, which we model with Bayesian additive regression trees, and by a likelihood over observed events. We refrain from specifying a priori the longitudinal process, interactions with other variables, or shape of survival curves by augmenting the dimensionality of the data and framing the problem in a discrete survival setting, which as a result leads to a fully data-driven Bayesian nonparametric model.
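The sketch below makes the discrete-time framing concrete. It expands each subject into one row per time interval at risk, with time itself as a feature and a binary label marking the interval in which the event occurred. It then converts discrete-time hazards into a survival curve. This is a generic person-period construction written for illustration, not the exact augmentation used in BNDS. A flexible classifier would then be fit to `rows` and `labels`; BNDS uses Bayesian additive regression trees for this.

```python
import numpy as np

def to_person_period(times, events, covariates):
    """Expand (observed discrete time, event indicator, covariates) into one row per
    interval at risk. The label is 1 only in the interval where the event occurred;
    censored subjects contribute only 0 labels. Time is an explicit feature, so a
    flexible model can learn covariate-by-time interactions."""
    rows, labels = [], []
    for t_obs, event, x in zip(times, events, covariates):
        for t in range(1, t_obs + 1):
            rows.append(np.concatenate([[t], x]))
            labels.append(int(event and t == t_obs))
    return np.array(rows), np.array(labels)

def survival_from_hazards(hazards):
    """S(t) = prod_{s <= t} (1 - h(s)), with discrete-time hazard h(s) = P(T = s | T >= s)."""
    return np.cumprod(1.0 - np.asarray(hazards))

times = [3, 2]                 # hypothetical follow-up, in discrete intervals
events = [1, 0]                # subject 0 had the event, subject 1 was censored
X = np.array([[0.5], [1.2]])   # one hypothetical static covariate per subject
rows, labels = to_person_period(times, events, X)
S = survival_from_hazards([0.1, 0.2, 0.5])
```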

BNDS can use sparse longitudinal data to give personalized survival predictions that are updated as new information is recorded. Our approach has the advantage of not imposing any constraints on the data generating process, which, together with novel postprocessing statistics, expands the capabilities of current methods.
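Individual-level uncertainty of this kind can be summarized in two steps. First, push each posterior sample of a patient's hazards through the survival formula. Then take pointwise quantiles of the resulting curves. In the sketch below, Beta-distributed draws are a hypothetical stand-in for posterior samples from a fitted model. The point is only that sparser evidence, reflected in a more diffuse posterior, yields a wider credible band.

```python
import numpy as np

def survival_credible_band(hazard_draws, level=0.95):
    """hazard_draws: (n_draws, T) posterior samples of discrete-time hazards for one patient.
    Each draw gives a survival curve; pointwise quantiles give a credible band."""
    curves = np.cumprod(1.0 - hazard_draws, axis=1)
    lo, hi = np.quantile(curves, [(1 - level) / 2, (1 + level) / 2], axis=0)
    return curves.mean(axis=0), lo, hi

rng = np.random.default_rng(1)
# Hypothetical stand-ins for posterior draws: both have mean hazard 0.1,
# but the sparsely observed patient's posterior is far more diffuse.
sparse = rng.beta(2, 18, size=(2000, 5))
dense = rng.beta(20, 180, size=(2000, 5))
mean_s, lo_s, hi_s = survival_credible_band(sparse)
mean_d, lo_d, hi_d = survival_credible_band(dense)
```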

Flexible modelling of longitudinal medical data: a Bayesian nonparametric approach

Alexis Bellot, Mihaela van der Schaar

ACM Transactions on Computing for Healthcare, 2020

Using electronic medical records to learn personalized risk trajectories poses significant challenges because often very few samples are available in a patient’s history, and, when available, their information content is highly diverse.

In this article, we consider how to integrate sparsely sampled longitudinal data, missing measurements informative of the underlying health status, and static information to estimate (dynamically, as new information becomes available) personalized survival distributions.

We achieve this by developing a nonparametric probabilistic model that generates survival trajectories, and corresponding uncertainty estimates, from an ensemble of Bayesian trees in which time is incorporated explicitly to learn variable interactions over time, without needing to specify the longitudinal process beforehand. As such, the changing influence on survival of variables over time is inferred from the data directly, which we analyze with post-processing statistics derived from our model.

We study the problem of personalizing survival estimates of patients in heterogeneous populations for clinical decision support. The desiderata are to improve predictions by making them personalized to the patient-at-hand, to better understand diseases and their risk factors, and to provide interpretable model outputs to clinicians.

To enable accurate survival prognosis in heterogeneous populations we propose a novel probabilistic survival model which flexibly captures individual traits through a hierarchical latent variable formulation. Survival paths are estimated by jointly sampling the location and shape of the individual survival distribution resulting in patient-specific curves with quantifiable uncertainty estimates. An understanding of model predictions is paramount in medical practice where decisions have major social consequences.
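Below is a toy version of jointly sampling the location and shape of an individual survival distribution. The Weibull family, the log-normal patient-level noise, and the covariates and coefficients are all assumptions chosen for illustration, not the model in the paper. The output is a set of sampled survival curves, from which a median curve and an uncertainty band follow directly.

```python
import numpy as np

def individual_survival_draws(x, beta, tau_scale, tau_shape, t, n_draws=4000, seed=0):
    """Hypothetical hierarchical Weibull sketch.
    Location: log-scale = x @ beta + patient-level noise with sd tau_scale.
    Shape: log-shape drawn around 0 (shape around 1) with sd tau_shape.
    Returns (n_draws, len(t)) survival curves S(t) = exp(-(t / scale) ** shape)."""
    rng = np.random.default_rng(seed)
    log_scale = x @ beta + tau_scale * rng.standard_normal(n_draws)
    log_shape = tau_shape * rng.standard_normal(n_draws)
    scale = np.exp(log_scale)[:, None]
    shape = np.exp(log_shape)[:, None]
    return np.exp(-(t[None, :] / scale) ** shape)

t = np.linspace(0.0, 10.0, 11)
x = np.array([1.0, 0.5])        # hypothetical patient covariates
beta = np.array([1.5, 0.4])     # hypothetical coefficients
curves = individual_survival_draws(x, beta, tau_scale=0.3, tau_shape=0.2, t=t)
median = np.median(curves, axis=0)
lo, hi = np.quantile(curves, [0.05, 0.95], axis=0)
```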

We develop a personalized interpreter that can be used to test the effect of covariates on each individual patient, in contrast to traditional methods that focus on population average effects.

We extensively validated the proposed approach in various clinical settings, with a special focus on cardiovascular disease.

Building personalized comorbidity networks

The causal structure of relationships between diseases can be represented by dynamic networks, in which the strength of the edges between nodes changes over time depending on the patient’s entire history. In most cases, the underlying network dynamics are unknown, and what we observe are sequences of events spreading over the network. To infer the latent network dynamics from observed sequences, one needs to take into account both when and what events occurred in the past, since both carry information on the mechanisms involved in disease instantiation and progression.

This is a challenge we addressed through the development of deep diffusion processes (DDP) to model dynamic comorbidity networks. Our work in this area is showcased in a paper published at AISTATS 2020.

DDP offers a deep probabilistic model for diffusion over comorbidity networks based on mutually interacting point processes. We modeled DDP’s intensity function as a combination of contextualized background risk and networked disease interaction, using a deep neural network to (dynamically) update each disease’s influence on future events. This enables principled predictions based on clinically interpretable parameters, which map patient history onto personalized comorbidity networks.
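The intensity described above is a background risk plus contributions from previously diagnosed diseases passed along network edges, and it can be sketched in Hawkes-process style. Here the edge weights and background risks are fixed arrays, and the influence of a past diagnosis decays exponentially. In DDP both the weights and the background are produced by neural networks from the patient's history, and the exponential kernel is our assumption.

```python
import numpy as np

def networked_intensity(t, history, background, W, decay=1.0):
    """Intensity for every disease at time t.
    history: list of (onset_time, disease_index) pairs.
    background: (D,) baseline risk per disease (context-dependent in DDP; fixed here).
    W: (D, D) edge weights, W[i, j] = influence of disease i on disease j
       (produced from patient history by a neural network in DDP; fixed here).
    decay: rate of a hypothetical exponential kernel."""
    lam = np.array(background, dtype=float)
    for t_j, d_j in history:
        if t_j < t:                                   # only past onsets contribute
            lam += W[d_j] * np.exp(-decay * (t - t_j))
    return lam

background = [0.1, 0.1, 0.1]
W = np.zeros((3, 3))
W[0, 1] = 0.5                 # hypothetical: disease 0 raises the risk of disease 1
lam_after = networked_intensity(2.0, [(1.0, 0)], background, W)
lam_before = networked_intensity(0.5, [(1.0, 0)], background, W)
```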

The dynamic comorbidity network learned by DDP for an individual patient at three time steps, together with the corresponding intensity function. Nodes for diseases that have not occurred are colored in gray, and diseases already diagnosed are assigned a distinct color. Edge thickness corresponds to the disease likelihood at the given time step. In the upper left panel, we plot the Jaccard distance of the patient’s network with respect to the average population as a function of time (on a logarithmic scale). The static comorbidity network obtained by counting disease co-occurrences and using the counts as graph edges is depicted on the right panel.
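Both quantities mentioned in the caption are easy to compute. The sketch below builds the static co-occurrence network from per-patient disease sets. It then compares two weighted networks with a weighted (Ruzicka) form of the Jaccard distance, `1 - sum(min) / sum(max)`. The caption does not say whether the figure uses this weighted variant or a binarized one, so treat the choice as an assumption. The patient network is a hypothetical array.

```python
import numpy as np
from itertools import combinations

def cooccurrence_network(patients, n_diseases):
    """Static network: edge weight = number of patients in whom both diseases occur."""
    C = np.zeros((n_diseases, n_diseases))
    for diseases in patients:
        for i, j in combinations(sorted(set(diseases)), 2):
            C[i, j] += 1
            C[j, i] += 1
    return C

def weighted_jaccard_distance(A, B):
    """1 - sum(min) / sum(max) over edge weights; 0 for identical networks."""
    num = np.minimum(A, B).sum()
    den = np.maximum(A, B).sum()
    return 0.0 if den == 0 else 1.0 - num / den

patients = [[0, 1], [0, 1, 2], [1, 2]]             # hypothetical disease sets
C = cooccurrence_network(patients, 3)
population = C / len(patients)                     # average network over patients
patient_net = np.zeros((3, 3))
patient_net[0, 1] = patient_net[1, 0] = 0.9        # hypothetical personalized network
dist = weighted_jaccard_distance(patient_net, population)
```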

Learning dynamic and personalized comorbidity networks from event data using deep diffusion processes

Zhaozhi Qian, Ahmed Alaa, Alexis Bellot, Mihaela van der Schaar, Jem Rashbass

AISTATS 2020

Comorbid diseases co-occur and progress via complex temporal patterns that vary among individuals. In electronic medical records, we only observe onsets of diseases, but not their triggering comorbidities — i.e., the mechanisms underlying temporal relations between diseases need to be inferred. Learning such temporal patterns from event data is crucial for understanding disease pathology and predicting prognoses.

To this end, we develop deep diffusion processes (DDP) to model “dynamic comorbidity networks”, i.e., the temporal relationships between comorbid disease onsets expressed through a dynamic graph.

A DDP comprises events modelled as a multi-dimensional point process, with an intensity function parameterized by the edges of a dynamic weighted graph. The graph structure is modulated by a neural network that maps patient history to edge weights, enabling rich temporal representations for disease trajectories. The DDP parameters decouple into clinically meaningful components, which enables serving the dual purpose of accurate risk prediction and intelligible representation of disease pathology.

We illustrate these features in experiments using cancer registry data.

Phenotyping and subgroup identification

Phenotyping and identifying subgroups of patients are important challenges in survival analysis. These become particularly complicated in a dynamic setting where longitudinal datasets are in use.

Chronic diseases such as cystic fibrosis and dementia are heterogeneous in nature, with widely differing outcomes—even in narrow patient subgroups. Identifying patient subgroups with similar progression patterns can be advantageous for understanding such heterogeneous diseases. This allows clinicians to anticipate patients’ prognoses by comparing to “similar” patients and to design treatment guidelines tailored to homogeneous subgroups.

Temporal clustering has been recently used as a data-driven framework to partition patients with time-series observations into subgroups of patients. Recent research has typically focused on either finding fixed-length and low-dimensional representations or on modifying the similarity measure, both in an attempt to apply the existing clustering algorithms to time-series observations.

However, clusters identified from these approaches are purely unsupervised – they do not account for patients’ observed outcomes (e.g., adverse events, the onset of comorbidities, etc.) – which leads to heterogeneous clusters if the clinical presentation of the disease differs even for patients with the same outcomes. Thus, a common prognosis in each cluster remains unknown, which can obscure the underlying disease progression.

To overcome this limitation, our lab developed an approach based on predictive clustering with the aim of combining predictions of future outcomes with clustering. This approach to temporal phenotyping using deep predictive clustering was introduced in a paper published at ICML 2020.

Temporal clustering is challenging because (i) the data is often high-dimensional, consisting of sequences with not only many features but also many time points, and (ii) defining a proper similarity measure for time series is not straightforward, since such measures are often highly sensitive to distortions.

Our approach uses an actor-critic method for temporal predictive clustering. The key insight is that we model temporal predictive clustering as learning discrete representations of the input time series. More specifically, input time series are mapped into continuous latent encodings, which are then assigned to clusters (i.e., selected discrete representations) that best describe the future outcome distribution based on novel loss functions.
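The following is a forward-pass sketch of the idea. Encodings are assigned to the nearest of a small set of cluster embeddings. The outcome is then predicted from the cluster embedding rather than from the patient's own encoding, so every patient in a cluster shares one predicted outcome distribution. The squared-distance softmax, the linear outcome predictor `V`, and the toy data are hypothetical. The sketch omits both the actor-critic training of the discrete assignment step and the paper's specific loss functions.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def assign_clusters(z, embeddings):
    """z: (N, d) continuous encodings; embeddings: (C, d) cluster embeddings.
    Returns soft assignment probabilities (stand-in for the selector) and hard assignments."""
    d2 = ((z[:, None, :] - embeddings[None, :, :]) ** 2).sum(-1)   # (N, C) squared distances
    probs = softmax(-d2, axis=1)
    return probs, probs.argmax(axis=1)

def predictive_loss(assign, embeddings, V, y):
    """Cross-entropy between observed outcomes y (N,) and the outcome distribution predicted
    from each patient's cluster embedding. V: (d, n_outcomes) hypothetical linear predictor."""
    p = softmax(embeddings[assign] @ V, axis=1)
    return -np.mean(np.log(p[np.arange(len(y)), y]))

embeddings = np.array([[0.0, 0.0], [5.0, 5.0]])
z = np.array([[0.1, -0.2], [4.8, 5.1], [0.3, 0.2], [5.2, 4.9]])
V = np.array([[1.0, -1.0], [1.0, -1.0]])
y = np.array([1, 0, 1, 0])
probs, assign = assign_clusters(z, embeddings)
loss = predictive_loss(assign, embeddings, V, y)
```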

Survival analysis - predictive clustering

A conceptual illustration of the (real-time) temporal phenotyping procedure. In this example, we focus on patients diagnosed with breast cancer, where the clinical outcomes of interest are the recurrence of cancer and cancer-related death. At run time, the new patient is assigned to one of three phenotypes as new observations are collected over time. In our outcome-oriented notion of phenotyping, the new patient is assigned to phenotype 2 at time t1, which consists of past patients with a high risk of recurrence. Then, this new patient is assigned to phenotype 3 at time t2, which consists of past patients who died from cancer-related causes in the near future. However, in the traditional notion of phenotyping, the new patient remains in the same phenotype at both t1 and t2, since the longitudinal observations remain similar to those of past patients in this phenotype.

Temporal phenotyping using deep predictive clustering of disease progression

Changhee Lee, Mihaela van der Schaar

ICML 2020

Due to the wider availability of modern electronic health records, patient care data is often being stored in the form of time-series. Clustering such time-series data is crucial for patient phenotyping, anticipating patients’ prognoses by identifying “similar” patients, and designing treatment guidelines that are tailored to homogeneous patient subgroups.

In this paper, we develop a deep learning approach for clustering time-series data, where each cluster comprises patients who share similar future outcomes of interest (e.g., adverse events, the onset of comorbidities). To encourage each cluster to have homogeneous future outcomes, the clustering is carried out by learning discrete representations that best describe the future outcome distribution based on novel loss functions.

Experiments on two real-world datasets show that our model achieves superior clustering performance over state-of-the-art benchmarks and identifies meaningful clusters that can be translated into actionable information for clinical decision-making.

Chronic diseases evolve slowly throughout a patient’s lifetime creating heterogeneous progression patterns that make clinical outcomes remarkably varied across individual patients. A tool capable of identifying temporal phenotypes based on the patients’ different progression patterns and clinical outcomes would allow clinicians to better forecast disease progression by recognizing a group of similar past patients, and to better design treatment guidelines that are tailored to specific phenotypes.

To build such a tool, we propose a deep learning approach, which we refer to as outcome-oriented deep temporal phenotyping (ODTP), to identify temporal phenotypes of disease progression considering what type of clinical outcomes will occur and when based on the longitudinal observations. More specifically, we model clinical outcomes throughout a patient’s longitudinal observations via time-to-event (TTE) processes whose conditional intensity functions are estimated as non-linear functions using a recurrent neural network. Temporal phenotyping of disease progression is carried out by our novel loss function that is specifically designed to learn discrete latent representations that best characterize the underlying TTE processes.

The key insight here is that learning such discrete representations groups progression patterns considering the similarity in expected clinical outcomes, and thus naturally provides outcome-oriented temporal phenotypes.

We demonstrate the power of ODTP by applying it to a real-world heterogeneous cohort of 11,779 stage III breast cancer patients from the UK National Cancer Registration and Analysis Service. The experiments show that ODTP identifies temporal phenotypes that are strongly associated with the future clinical outcomes and achieves significant gain on the homogeneity and heterogeneity measures over existing methods.

Furthermore, we are able to identify the key driving factors that lead to transitions between phenotypes which can be translated into actionable information to support better clinical decision-making.

New metrics of performance

The sections above have offered simple introductions to some of the challenges inherent in modeling survival functions in static and dynamic settings, and with one or multiple events of interest. We would like to conclude by briefly discussing the measurement of performance of models for survival analysis—in many ways, this is as important and challenging an area as the development of the models themselves.

One widely used metric to measure the discriminative ability of a model for survival analysis with competing risks is the concordance index. Introduced in 1982, the concordance index aims to measure the performance of a medical test by determining how much prognostic information the test can provide about an individual patient.

While the concordance index can be effectively used to assess the prognostic ability of a model for one event type of interest in the presence of competing risks, it does not capture whether the model is good at predicting the event type itself. This makes it ill-suited to applications with more than one event type of interest. In such applications (such as treatment planning in multimorbid populations or resource planning in critical care), evaluating a model’s ability to jointly predict the event type and the event time is often critical.
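For reference, a simple cause-specific concordance index can be computed as below. It uses one common, simplified definition: a pair is comparable when subject i had the cause of interest and subject j was still event-free at that time. It also omits the censoring weights used by more careful estimators. Event code 0 denotes censoring, and the data are hypothetical.

```python
def cause_specific_cindex(times, events, risk, cause):
    """Time-independent concordance for one cause in the presence of competing risks.
    Comparable pair (i, j): i experienced `cause` at times[i] and times[j] > times[i]
    (j's eventual outcome may be censoring or a competing event).
    Concordant if i has the higher predicted risk for that cause; ties count 0.5.
    No inverse-probability-of-censoring weighting is applied."""
    num, den = 0.0, 0
    for i in range(len(times)):
        if events[i] != cause:
            continue
        for j in range(len(times)):
            if times[j] > times[i]:
                den += 1
                if risk[i] > risk[j]:
                    num += 1
                elif risk[i] == risk[j]:
                    num += 0.5
    return num / den if den else float("nan")

times = [2, 4, 5, 7]
events = [1, 2, 1, 0]          # causes 1 and 2; 0 = censored
risk = [0.9, 0.95, 0.6, 0.1]   # hypothetical predicted risks for cause 1
c1 = cause_specific_cindex(times, events, risk, cause=1)
```

This index says nothing about whether the model would have predicted cause 1 rather than cause 2 for subject 0, which is the gap described in the paragraph above.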

To address this shortcoming, our lab developed a new metric that we call the joint concordance index. The joint concordance index is the probability that a given model accurately predicts the event type for a subject while also ranking that subject’s risk correctly among the other subjects. The index can be interpreted by decomposing it into an accuracy-like term (how often the event type is predicted correctly) and a concordance term conditional on correct event-type predictions.
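The verbal definition above can be turned into a naive empirical estimate, sketched below. It counts comparable pairs in which the predicted event type for subject i (the argmax of its risks) is correct and i is also ranked above j for that type. It also reports the pair-weighted accuracy term and the conditional concordance, whose product equals the joint value. This is our illustrative reading, not the estimator from the paper, which additionally corrects for censoring bias. Ties are counted as discordant.

```python
import numpy as np

def joint_concordance(times, events, risks):
    """risks: (N, K) predicted risk for each event type 1..K; events: 0 = censored, k = type k.
    Comparable pair (i, j): i had an event and times[j] > times[i].
    Returns (joint, accuracy_term, conditional_concordance); joint = accuracy * conditional."""
    risks = np.asarray(risks)
    n_pairs = n_type_correct = n_joint = 0
    for i in range(len(times)):
        k = events[i]
        if k == 0:
            continue
        correct = int(np.argmax(risks[i])) + 1 == k      # predicted event type matches
        for j in range(len(times)):
            if times[j] > times[i]:
                n_pairs += 1
                if correct:
                    n_type_correct += 1
                    if risks[i, k - 1] > risks[j, k - 1]:  # ranked correctly for that type
                        n_joint += 1
    accuracy = n_type_correct / n_pairs
    conditional = n_joint / n_type_correct if n_type_correct else 0.0
    return n_joint / n_pairs, accuracy, conditional

times = [2, 4, 5, 7]
events = [1, 2, 1, 0]
risks = [[0.7, 0.2], [0.3, 0.6], [0.4, 0.65], [0.1, 0.1]]   # hypothetical predictions
joint, acc, cond = joint_concordance(times, events, risks)
```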

In our 2019 paper (below) introducing the joint concordance index, we demonstrated that existing approaches for variable importance ranking often fail to recognize the importance of event-specific risk factors, whereas a ranking approach based on the joint concordance index does not, since it compares risk factors based on their contribution to the prediction of different event types.

Joint concordance index

Kartik Ahuja, Mihaela van der Schaar

2019 Asilomar Conference on Signals, Systems, and Computers

Existing metrics in competing risks survival analysis such as concordance and accuracy do not evaluate a model’s ability to jointly predict the event type and the event time.

To address these limitations, we propose a new metric, which we call the joint concordance. The joint concordance measures a model’s ability to predict the overall risk profile, i.e., risk of death from different event types. We develop a consistent estimator for the new metric that accounts for the censoring bias. We use the new metric to develop a variable importance ranking approach.

Using experiments on real and synthetic data, we show that models selected using the existing metrics are worse than those selected using joint concordance at jointly predicting the event type and event time. We show that the existing approaches for variable importance ranking often fail to recognize the importance of the event-specific risk factors, whereas the proposed approach does not, since it compares risk factors based on their contribution to the prediction of the different event types.

To summarize, joint concordance is helpful for model comparisons and variable importance ranking and has the potential to impact applications such as risk-stratification and treatment planning in multimorbid populations.

Find out more and get involved

This page has served as an introduction to survival analysis, competing risks, and comorbidities—from the perspective of both healthcare and machine learning.

We have demonstrated the challenges inherent in creating (and evaluating) effective models that are able to work across static and dynamic settings, can predict multiple events of interest, and are able to factor in the evolving interactions and causal relationships between multiple competing risks and comorbidities. Tools such as those above are key to accelerating the advent of “bespoke medicine” and truly moving beyond one-size-fits-all approaches.

We would also encourage you to stay abreast of ongoing developments in this and other areas of machine learning for healthcare by signing up to take part in one of our two streams of online engagement sessions.

If you are a practicing clinician, please sign up for Revolutionizing Healthcare, which is a forum for members of the clinical community to share ideas and discuss topics that will define the future of machine learning in healthcare (no machine learning experience required).

If you are a machine learning student, you can join our Inspiration Exchange engagement sessions, in which we introduce and discuss new ideas and development of new methods, approaches, and techniques in machine learning for healthcare.

A full list of our papers on this and related topics can be found here.