van der Schaar Lab

Spotlight on cancer research projects

To confront cancer is to encounter a parallel species,
one perhaps more adapted to survival than even we are.

– Siddhartha Mukherjee, The Emperor of All Maladies

The term cancer embraces a wide variety of related disorders and conditions that share many similarities but differ in just as many ways. This daunting complexity becomes more apparent with every breakthrough in our quest to understand it. It is manifold, ranging from the bewildering array of disease subtypes (and subtypes of subtypes) to variations in cause and presentation, to the lengthy and unpredictable pathways inflicted on patients.

While the notion of developing a single “magic bullet” to cure cancer is outdated, ongoing research advancements have at least allowed us to develop a substantial arsenal in areas such as prevention, prediction, detection, diagnosis, treatment, and care. Truly revolutionizing our ability to combat cancer, however, requires an altogether deeper understanding of its disease pathways, and I believe this can only be achieved through the adoption of machine learning methods.

The potential of machine learning in combating cancer is a topic I addressed in our most recent Revolutionizing Healthcare engagement session. To view that session (in which I also explain many of the methods and approaches detailed later in this post), or to sign up for the Revolutionizing Healthcare series, please click the links below.

This post will highlight and summarize some of our lab’s key projects related to cancer. Our summary will follow a slightly simplified chronological representation of the standard cancer patient’s pathway: at each stage along the pathway, we introduce specific projects and provide resources for further reading.

Genetic risk

Polygenic risk scores play an important role in determining an individual’s risk of developing cancer during their lifetime. To date, only linear models have been successfully applied to crafting genomic risk scores using genomics data. This raises the question: can machine learning improve on linear models in crafting polygenic risk scores?

To address this, our lab recently created VIME, a machine learning framework for crafting polygenic risk scores combining self- and semi-supervised learning.

VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain

Jinsung Yoon, Yao Zhang, James Jordon, Mihaela van der Schaar

Self- and semi-supervised learning frameworks have made significant progress in training machine learning models with limited labeled data in image and language domains. These methods heavily rely on the unique structure in the domain datasets (such as spatial relationships in images or semantic relationships in language). They are not adaptable to general tabular data which does not have the same explicit structure as image and language data.

In this paper, we fill this gap by proposing novel self- and semi-supervised learning frameworks for tabular data, which we refer to collectively as VIME (Value Imputation and Mask Estimation). We create a novel pretext task of estimating mask vectors from corrupted tabular data in addition to the reconstruction pretext task for self-supervised learning. We also introduce a novel tabular data augmentation method for self- and semi-supervised learning frameworks.

In experiments, we evaluate the proposed framework in multiple tabular datasets from various application domains, such as genomics and clinical data. VIME exceeds state-of-the-art performance in comparison to the existing baseline methods.
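As a rough illustration of the mask-estimation pretext task described above, the following Python sketch corrupts a small table by replacing randomly masked cells with values drawn from the same column. It is a simplified sketch, not the VIME implementation: the function names `make_mask` and `corrupt`, the masking probability, and the toy table are all assumptions made for this example.

```python
import random

def make_mask(n_rows, n_cols, p_mask, rng):
    """Sample a binary mask: 1 marks a cell that will be corrupted."""
    return [[1 if rng.random() < p_mask else 0 for _ in range(n_cols)]
            for _ in range(n_rows)]

def corrupt(rows, mask, rng):
    """Replace masked cells with values drawn from the same column.

    Shuffling each column independently samples from that feature's
    empirical marginal distribution.
    """
    n_rows, n_cols = len(rows), len(rows[0])
    shuffled_cols = []
    for j in range(n_cols):
        col = [row[j] for row in rows]
        rng.shuffle(col)
        shuffled_cols.append(col)
    corrupted = [[shuffled_cols[j][i] if mask[i][j] else rows[i][j]
                  for j in range(n_cols)] for i in range(n_rows)]
    # A masked cell can receive its own original value by chance, so the
    # target mask only marks cells whose value actually changed.
    target_mask = [[1 if corrupted[i][j] != rows[i][j] else 0
                    for j in range(n_cols)] for i in range(n_rows)]
    return corrupted, target_mask

# Hypothetical toy table: 4 samples, 2 features.
rng = random.Random(0)
rows = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
mask = make_mask(len(rows), 2, p_mask=0.3, rng=rng)
corrupted, target_mask = corrupt(rows, mask, rng)
```

A self-supervised encoder would then be trained to predict `target_mask` and to reconstruct `rows` from `corrupted`; that training step is omitted here.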


In addition to genetic factors, lifestyle (including socio-demographics) plays a major role in determining an individual’s cancer risk.

While current statistical risk scoring models only use a handful of factors that have been identified as potentially important, we know that there are other factors that may be just as important (or more important).

In this context, we can apply machine learning in the service of two main objectives:
1) identifying, out of a large number of potentially informative risk factors (including socio-demographic information), which factors are most relevant for issuing an accurate prediction, in essence determining the value of information for a particular individual or class of individuals; and
2) understanding when non-linear interactions between identified factors are important and moving beyond linear models to non-linear models.

For this, we developed AutoPrognosis, our machine learning tool for crafting clinical scores. You can learn more about AutoPrognosis here, or by reading the paper below.

AutoPrognosis: Automated Clinical Prognostic Modeling via Bayesian Optimization with Structured Kernel Learning

Ahmed Alaa, Mihaela van der Schaar

Clinical prognostic models derived from large-scale healthcare data can inform critical diagnostic and therapeutic decisions.

To enable off-the-shelf usage of machine learning (ML) in prognostic research, we developed AutoPrognosis: a system for automating the design of predictive modeling pipelines tailored for clinical prognosis. AutoPrognosis optimizes ensembles of pipeline configurations efficiently using a novel batched Bayesian optimization (BO) algorithm that learns a low-dimensional decomposition of the pipelines’ high-dimensional hyperparameter space in concurrence with the BO procedure. This is achieved by modeling the pipelines’ performances as a black-box function with a Gaussian process prior, and modeling the “similarities” between the pipelines’ baseline algorithms via a sparse additive kernel with a Dirichlet prior. Meta-learning is used to warmstart BO with external data from “similar” patient cohorts by calibrating the priors using an algorithm that mimics the empirical Bayes method. The system automatically explains its predictions by presenting the clinicians with logical association rules that link patients’ features to predicted risk strata.

We demonstrate the utility of AutoPrognosis using 10 major patient cohorts representing various aspects of cardiovascular patient care.


Early diagnosis

For early diagnosis, we use available healthcare data to understand how health and disease trajectories progress, which is essential for identifying cancer early in a patient.

We can use machine learning and the wealth of data available about the patient (symptoms, clinical findings, imaging results, lab tests, possible treatments given, and the timing of all of these) to issue predictions and forecasts, including early diagnosis of cancer onset and, upon diagnosis, the likely severity of disease progression.

For this, we need to use machine learning to build data-driven dynamic forecasting models that are personalized, accurate, and interpretable. These can be used for early diagnosis, as well as (if cancer has been identified) personalized monitoring and forecasting disease progression.

A hidden absorbing semi-Markov model for informatively censored temporal data: learning and inference

Ahmed Alaa, Mihaela van der Schaar

Modeling continuous-time physiological processes that manifest a patient’s evolving clinical states is a key step in approaching many problems in healthcare.

In this paper, we develop the Hidden Absorbing Semi-Markov Model (HASMM): a versatile probabilistic model that is capable of capturing the modern electronic health record (EHR) data. Unlike existing models, the HASMM accommodates irregularly sampled, temporally correlated, and informatively censored physiological data, and can describe non-stationary clinical state transitions.

Learning the HASMM parameters from the EHR data is achieved via a novel forward-filtering backward-sampling Monte-Carlo EM algorithm that exploits the knowledge of the end-point clinical outcomes (informative censoring) in the EHR data, and implements the E-step by sequentially sampling the patients’ clinical states in the reverse-time direction while conditioning on the future states. Real-time inferences are drawn via a forward-filtering algorithm that operates on a virtually constructed discrete-time embedded Markov chain that mirrors the patient’s continuous-time state trajectory.
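To give a feel for forward filtering on a discrete-time Markov chain, here is a minimal Python sketch of the standard hidden Markov model filter. It is deliberately much simpler than the HASMM (no semi-Markov sojourn times, no informative censoring, no irregular sampling), and the two-state model and all probabilities below are hypothetical values chosen for illustration.

```python
def forward_filter(prior, transition, emission, observations):
    """Return the filtered state distribution after each observation.

    prior[i]          : P(state_0 = i)
    transition[i][j]  : P(state_t = j | state_{t-1} = i)
    emission[i][o]    : P(observation = o | state = i)
    """
    n = len(prior)
    belief = list(prior)
    history = []
    for t, obs in enumerate(observations):
        if t > 0:
            # Predict: push the previous belief through the transition matrix.
            belief = [sum(belief[i] * transition[i][j] for i in range(n))
                      for j in range(n)]
        # Update: weight by the likelihood of the new observation, normalize.
        unnorm = [belief[j] * emission[j][obs] for j in range(n)]
        total = sum(unnorm)
        belief = [u / total for u in unnorm]
        history.append(belief)
    return history

# Hypothetical two-state example: state 0 = stable, state 1 = deteriorating
# (absorbing). Observation 0 = normal vitals, 1 = abnormal vitals.
prior = [0.9, 0.1]
transition = [[0.8, 0.2],
              [0.0, 1.0]]
emission = [[0.9, 0.1],
            [0.3, 0.7]]
beliefs = forward_filter(prior, transition, emission, [0, 1, 1])
```

Each entry of `beliefs` is the real-time estimate of the patient's hidden state given the observations so far; in this toy run the probability of the deteriorating state rises as abnormal observations accumulate.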

We demonstrate the prognostic utility of the HASMM in a critical care prognosis setting using a real-world dataset for patients admitted to the Ronald Reagan UCLA Medical Center. In particular, we show that using HASMMs, a patient’s clinical deterioration can be predicted 8-9 hours prior to intensive care unit admission, with a 22% AUC gain compared to the Rothman index, which is the state-of-the-art critical care risk scoring technology.

Attentive State-Space Modeling of Disease Progression

Ahmed Alaa, Mihaela van der Schaar

Models of disease progression are instrumental for predicting patient outcomes and understanding disease dynamics. Existing models provide the patient with pragmatic (supervised) predictions of risk, but do not provide the clinician with intelligible (unsupervised) representations of disease pathophysiology.

In this paper, we develop the attentive state-space model, a deep probabilistic model that learns accurate and interpretable structured representations for disease trajectories. Unlike Markovian state-space models, in which the dynamics are memoryless, our model uses an attention mechanism to create “memoryful” dynamics, whereby attention weights determine the dependence of future disease states on past medical history.
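The idea of “memoryful” dynamics can be sketched as follows: instead of conditioning only on the most recent state, the next-state distribution is a mixture of the transition rows of all past states, weighted by attention. This is a simplified sketch under stated assumptions: in the paper the attention weights are learned by a neural network, whereas here they come from hypothetical hand-set scores, and the transition matrix is invented for the example.

```python
import math

def softmax(scores):
    """Convert arbitrary scores into weights that sum to one."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def next_state_distribution(past_states, attention_scores, transition):
    """Mix the transition rows of all past states using attention weights."""
    weights = softmax(attention_scores)
    n = len(transition)
    dist = [0.0] * n
    for state, w in zip(past_states, weights):
        for j in range(n):
            dist[j] += w * transition[state][j]
    return dist, weights

# Hypothetical three-stage disease with a progressive transition matrix.
transition = [[0.8, 0.2, 0.0],
              [0.0, 0.7, 0.3],
              [0.0, 0.0, 1.0]]
past_states = [0, 0, 1]

# Equal attention over the whole history.
uniform_dist, _ = next_state_distribution(past_states, [0.0, 0.0, 0.0], transition)
# Attention concentrated on the latest state recovers Markovian behaviour.
markov_dist, _ = next_state_distribution(past_states, [0.0, 0.0, 50.0], transition)
```

When all attention sits on the most recent state, the mixture collapses to the ordinary memoryless transition, which is why a Markov model can be read as a special case of this construction.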

To learn the model parameters from medical records, we develop an inference algorithm that simultaneously learns a compiled inference network and the model parameters, leveraging the attentive state-space representation to construct a “Rao-Blackwellized” variational approximation of the posterior state distribution.

Experiments on data from the UK Cystic Fibrosis registry show that our model demonstrates superior predictive accuracy and provides insights into the progression of chronic disease.

Dynamic personalized screening

As mentioned above, screening is another critical part of the presentation stage of the cancer pathway. Medicine has been moving from a one-size-fits-all approach towards dynamic personalized screening.

This is the approach we took in our work on DPSCREEN, a technology we developed a few years ago. DPSCREEN takes into account both the features (unique characteristics) of an individual and their past clinical and screening history.

DPSCREEN: Dynamic Personalized Screening

Kartik Ahuja, William Zame, Mihaela van der Schaar

Screening is important for the diagnosis and treatment of a wide variety of diseases. A good screening policy should be personalized to the disease, to the features of the patient and to the dynamic history of the patient (including the history of screening). The growth of electronic health records data has led to the development of many models to predict the onset and progression of different diseases. However, there has been limited work to address the personalized screening for these different diseases.

In this work, we develop the first framework to construct screening policies for a large class of disease models. The disease is modeled as a finite state stochastic process with an absorbing disease state. The patient observes an external information process (for instance, self-examinations, discovering comorbidities, etc.) which can trigger the patient to arrive at the clinician earlier than scheduled screenings. The clinician carries out the tests; based on the test results and the external information it schedules the next arrival. Computing the exactly optimal screening policy that balances the delay in the detection against the frequency of screenings is computationally intractable; this paper provides a computationally tractable construction of an approximately optimal policy.
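The trade-off between detection delay and screening frequency can be illustrated with a much simpler rule than the policy constructed in the paper. The sketch below assumes a constant per-period onset probability (`hazard`) for each patient and picks the longest wait for which the chance of onset before the next screen stays within a tolerance. The function name, the constant-hazard assumption, and the numbers are all hypothetical; this is not the DPSCREEN algorithm.

```python
def next_screen_interval(hazard, tolerance, max_interval):
    """Longest wait (in periods) whose onset probability stays within tolerance.

    With a constant per-period hazard h, the probability of onset within
    k periods is 1 - (1 - h) ** k. At least one period is always returned.
    """
    interval = 1
    while interval < max_interval:
        p_onset = 1.0 - (1.0 - hazard) ** (interval + 1)
        if p_onset > tolerance:
            break
        interval += 1
    return interval

# Hypothetical patients with different estimated per-period risk.
low_risk = next_screen_interval(hazard=0.001, tolerance=0.05, max_interval=10)
medium_risk = next_screen_interval(hazard=0.01, tolerance=0.05, max_interval=10)
high_risk = next_screen_interval(hazard=0.03, tolerance=0.05, max_interval=10)
```

Even this crude rule shows the personalization effect: the higher the estimated risk, the shorter the interval before the next screen. A dynamic policy would additionally update `hazard` after each test result and after any external information arrives.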

As an illustration, we make use of a large breast cancer data set. The constructed policy screens patients more or less often according to their initial risk (it is personalized to the features of the patient) and according to the results of previous screens (it is personalized to the history of the patient). In comparison with existing clinical policies, the constructed policy leads to large reductions (28-68%) in the number of screens performed while achieving the same expected delays in disease detection.

Using (co)morbidities to prevent or identify cancer

At present, morbidities and comorbidities are modeled in a one-size-fits-all, static fashion. This is often based on networks of relationships between these different morbidities.

Using machine learning, we are able to predict the likelihood of an individual developing a new morbidity, such as cancer, in the future. This can be done through the use of morbidity networks that are both personalized (i.e. they depend on the unique characteristics, such as genetic information, of each specific individual) and dynamic (i.e. they depend on the order in which morbidities occur).

Deep diffusion processes (DDP), developed by our lab last year, allow us to model the relationships between comorbid disease onsets expressed through a dynamic graph, meaning we can predict the onset of a new disease.

Learning Dynamic and Personalized Comorbidity Networks from Event Data using Deep Diffusion Processes

Zhaozhi Qian, Ahmed Alaa, Alexis Bellot, Mihaela van der Schaar, Jem Rashbass

Comorbid diseases co-occur and progress via complex temporal patterns that vary among individuals. In electronic medical records, we only observe onsets of diseases, but not their triggering comorbidities — i.e., the mechanisms underlying temporal relations between diseases need to be inferred. Learning such temporal patterns from event data is crucial for understanding disease pathology and predicting prognoses.

To this end, we develop deep diffusion processes (DDP) to model “dynamic comorbidity networks”, i.e., the temporal relationships between comorbid disease onsets expressed through a dynamic graph. A DDP comprises events modelled as a multi-dimensional point process, with an intensity function parameterized by the edges of a dynamic weighted graph. The graph structure is modulated by a neural network that maps patient history to edge weights, enabling rich temporal representations for disease trajectories. The DDP parameters decouple into clinically meaningful components, which enables serving the dual purpose of accurate risk prediction and intelligible representation of disease pathology.
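To make the phrase “intensity function parameterized by the edges of a weighted graph” more concrete, the sketch below computes the onset intensity of one disease from past onsets of others. It assumes a Hawkes-style form with exponential decay and fixed edge weights; in a DDP the weights are dynamic and produced by a neural network from patient history, so the functional form, the names, and the numbers here are assumptions for illustration only.

```python
import math

def intensity(target, t, history, base_rate, weights, decay):
    """Onset intensity of disease `target` at time t given past onsets.

    history                 : list of (onset_time, disease_index) pairs
    weights[source][target] : strength of the edge source -> target
    """
    rate = base_rate[target]
    for onset_time, source in history:
        if onset_time < t:
            # Each earlier onset excites the target, fading over time.
            rate += weights[source][target] * math.exp(-decay * (t - onset_time))
    return rate

# Hypothetical two-disease network: disease 0 raises the risk of disease 1.
base_rate = [0.05, 0.10]
weights = [[0.0, 0.5],
           [0.0, 0.0]]
no_history = intensity(1, 1.0, [], base_rate, weights, decay=1.0)
after_onset = intensity(1, 1.0, [(0.0, 0)], base_rate, weights, decay=1.0)
```

Reading the weight matrix row by row gives the comorbidity network: a non-zero entry `weights[a][b]` means an onset of disease `a` temporarily increases the intensity of disease `b`.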

We illustrate these features in experiments using cancer registry data.


Machine learning’s ability to assist with diagnosis has been particularly well-documented—especially with regard to areas such as imaging. In this post, I would like to move beyond those impressive but well-trodden paths, and consider how machine learning can improve a range of overall diagnostic processes and empower the human professionals behind those processes.

Triaging in the diagnosis process

A key priority in cancer diagnosis is managing the workload of radiologists to optimize accuracy, efficiency, and costs. Our challenge here is to ensure that radiologists can devote the right amount of time to viewing scans that actually need their attention, meaning such scans must be separated out from others which can simply be read using machine learning or similar technologies.

MAMMO is a framework for cooperation between radiologists and machine learning. The focus of MAMMO is to triage mammograms between machine learning systems and radiologists.

Improving Workflow Efficiency for Mammography using Machine Learning

Trent Kyono, Fiona J Gilbert, Mihaela van der Schaar

Objective: The aim of this study was to determine whether machine learning could reduce the number of mammograms the radiologist must read by using a machine-learning classifier to correctly identify normal mammograms and to select the uncertain and abnormal examinations for radiological interpretation.

Methods: Mammograms in a research data set from over 7,000 women who were recalled for assessment at six UK National Health Service Breast Screening Program centers were used. A convolutional neural network in conjunction with multitask learning was used to extract imaging features from mammograms that mimic the radiological assessment provided by a radiologist, the patient’s nonimaging features, and pathology outcomes. A deep neural network was then used to concatenate and fuse multiple mammogram views to predict both a diagnosis and a recommendation of whether or not additional radiological assessment was needed.

Results: Ten-fold cross-validation was used on 2,000 randomly selected patients from the data set; the remainder of the data set was used for convolutional neural network training. While maintaining an acceptable negative predictive value of 0.99, the proposed model was able to identify 34% (95% confidence interval, 25%-43%) and 91% (95% confidence interval: 88%-94%) of the negative mammograms for test sets with a cancer prevalence of 15% and 1%, respectively.
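The triage step reported above can be sketched as a threshold search: given a model score per mammogram, find the largest cutoff such that the scans at or below it still meet a target negative predictive value, then report what share of normal scans would no longer need a radiologist. This is a minimal sketch with a hypothetical function name and invented toy data, not the procedure used in the study.

```python
def triage_threshold(scores, labels, min_npv):
    """Find the largest cutoff whose below-cutoff scans meet the NPV target.

    scores : model-estimated probability of abnormality per scan
    labels : 1 for cancer, 0 for normal
    Returns (cutoff, fraction_of_normals_triaged); scans scoring at or
    below the cutoff would be read by the model only.
    """
    ranked = sorted(zip(scores, labels))
    total_normals = sum(1 for y in labels if y == 0)
    best = (None, 0.0)
    normals_so_far = 0
    for k, (score, label) in enumerate(ranked, start=1):
        if label == 0:
            normals_so_far += 1
        # Skip cutoffs that would split scans sharing the same score.
        if k < len(ranked) and ranked[k][0] == score:
            continue
        npv = normals_so_far / k
        if npv >= min_npv:
            best = (score, normals_so_far / total_normals)
    return best

# Hypothetical scores and ground-truth labels for six scans.
scores = [0.02, 0.05, 0.10, 0.40, 0.80, 0.90]
labels = [0, 0, 0, 1, 0, 1]
cutoff, normals_triaged = triage_threshold(scores, labels, min_npv=0.99)
```

In this toy data the three lowest-scoring scans are all normal, so they can be set aside while the remaining three go to the radiologist. In practice the cutoff would be chosen on held-out data, since an NPV measured on the same scans used for tuning is optimistic.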

Conclusion: Machine learning was leveraged to successfully reduce the number of normal mammograms that radiologists need to read without degrading diagnostic accuracy.

Determining personalized screening modality

Our lab has also developed a system called ConfidentCare, which, like MAMMO, aims to improve accuracy and efficiency of resource usage within the overall diagnostic process.

ConfidentCare is a machine learning clinical decision support system that identifies what type of screening modality (e.g. mammogram, ultrasound, MRI) should be used for specific individuals, given their unique characteristics such as genomic information or past screening history.

ConfidentCare: A Clinical Decision Support System for Personalized Breast Cancer Screening

Ahmed Alaa, Kyeong H. Moon, William Hsu, Mihaela van der Schaar

Breast cancer screening policies attempt to achieve timely diagnosis by regularly screening healthy women via various imaging tests. Various clinical decisions are needed to manage the screening process: selecting initial screening tests, interpreting test results, and deciding if further diagnostic tests are required.

Current screening policies are guided by clinical practice guidelines (CPGs), which represent a “one-size-fits-all” approach, designed to work well (on average) for a population, and can only offer coarse expert-based patient stratification that is not rigorously validated through data. Since the risks and benefits of screening tests are functions of each patient’s features, personalized screening policies tailored to the features of individuals are desirable.

To address this issue, we developed ConfidentCare: a computer-aided clinical decision support system that learns a personalized screening policy from electronic health record (EHR) data. By a “personalized screening policy,” we mean a clustering of women’s features, and a set of customized screening guidelines for each cluster. ConfidentCare operates by computing clusters of patients with similar features, then learning the “best” screening procedure for each cluster using a supervised learning algorithm. The algorithm ensures that the learned screening policy satisfies a predefined accuracy requirement with a high level of confidence for every cluster.
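The requirement that a policy meet an accuracy target “with a high level of confidence for every cluster” can be illustrated with a simple selection rule. The sketch below assumes the patient clusters are already given, uses a one-sided Hoeffding bound as the confidence statement, and picks the cheapest modality that passes. The bound, the costs, and the per-cluster counts are assumptions chosen for illustration and are not taken from ConfidentCare.

```python
import math

def lower_confidence_bound(correct, n, delta):
    """One-sided Hoeffding lower bound on accuracy (holds w.p. >= 1 - delta)."""
    return correct / n - math.sqrt(math.log(1.0 / delta) / (2.0 * n))

def choose_modality(cluster_stats, costs, min_accuracy, delta):
    """Cheapest modality whose accuracy bound meets the requirement, else None."""
    eligible = [m for m, (correct, n) in cluster_stats.items()
                if lower_confidence_bound(correct, n, delta) >= min_accuracy]
    return min(eligible, key=lambda m: costs[m]) if eligible else None

# Hypothetical relative costs and (correct, total) counts per modality.
costs = {"mammogram": 1, "ultrasound": 2, "MRI": 5}
cluster_a = {"mammogram": (960, 1000), "ultrasound": (965, 1000), "MRI": (990, 1000)}
cluster_b = {"mammogram": (880, 1000), "ultrasound": (900, 1000), "MRI": (980, 1000)}

policy = {
    "cluster_a": choose_modality(cluster_a, costs, min_accuracy=0.90, delta=0.05),
    "cluster_b": choose_modality(cluster_b, costs, min_accuracy=0.90, delta=0.05),
}
```

Here the first cluster can be screened with the cheapest test, while the second needs a more expensive modality to satisfy the same requirement, which is the kind of per-cluster guideline a personalized screening policy consists of.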

By applying ConfidentCare to real-world data, we show that it outperforms the current CPGs in terms of cost efficiency and false positive rates: a reduction of 31% in the false positive rate can be achieved.

Referral and composition of multidisciplinary teams

Determining the composition of multidisciplinary teams (MDTs) can be one of the most complex parts of the diagnosis and treatment process.

This is a process that can be made substantially more efficient and effective through the use of machine learning-enabled recommender systems. These systems can identify which clinicians should come together to best decide on treatment options for a cancer patient, based on particular patient and clinician characteristics.

A few years ago, we built a recommender system that can “discover the experts” by assessing the context of the patient and determining the characteristics required of individual clinicians within the MDT—as well as determining the kind of machine learning decision support tools that should be used by this specific MDT for this specific patient.

Discover the Expert: Context-Adaptive Expert Selection for Medical Diagnosis

Cem Tekin, Onur Atan, Mihaela van der Schaar

In this paper, we propose an expert selection system that learns online the best expert to assign to each patient depending on the context of the patient.

In general, the context can include an enormous amount and variety of information related to the patient’s health condition, age, gender, previous drug doses, and so forth, but the most relevant information is embedded in only a few contexts. If these most relevant contexts were known in advance, learning would be relatively simple, but they are not. Moreover, the relevant contexts may be different for different health conditions.

To address these challenges, we develop a new class of algorithms aimed at discovering the most relevant contexts and the best clinic and expert to use to make a diagnosis given a patient’s contexts. We prove that as the number of patients grows, the proposed context-adaptive algorithm will discover the optimal expert to select for patients with a specific context. Moreover, the algorithm also provides confidence bounds on the diagnostic accuracy of the expert it selects, which can be considered by the primary care physician before making the final decision.
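Online expert selection with confidence bounds can be sketched as a contextual bandit. The example below runs a UCB1-style rule separately within each context bucket: it tries every expert once, then prefers the expert whose observed accuracy plus an exploration bonus is highest. It assumes the relevant context is already known and discretized, which is exactly the part the paper's algorithm has to discover, so this is a simplified stand-in rather than the proposed method. The class name and labels are hypothetical.

```python
import math

class ExpertSelector:
    """UCB1-style selection of an expert within each context bucket."""

    def __init__(self, experts):
        self.experts = list(experts)
        self.counts = {}   # (context, expert) -> number of assignments
        self.correct = {}  # (context, expert) -> number of correct diagnoses

    def select(self, context):
        total = sum(self.counts.get((context, e), 0) for e in self.experts)
        best, best_score = None, -1.0
        for e in self.experts:
            n = self.counts.get((context, e), 0)
            if n == 0:
                return e  # try every expert once per context
            mean = self.correct[(context, e)] / n
            bonus = math.sqrt(2.0 * math.log(total) / n)
            if mean + bonus > best_score:
                best, best_score = e, mean + bonus
        return best

    def update(self, context, expert, was_correct):
        key = (context, expert)
        self.counts[key] = self.counts.get(key, 0) + 1
        self.correct[key] = self.correct.get(key, 0) + (1 if was_correct else 0)

# Hypothetical usage with two experts and one context bucket.
selector = ExpertSelector(["expert_A", "expert_B"])
first = selector.select("dense_tissue")
selector.update("dense_tissue", first, was_correct=False)
second = selector.select("dense_tissue")
selector.update("dense_tissue", second, was_correct=True)
third = selector.select("dense_tissue")
```

The exploration bonus shrinks as an expert accumulates cases in a context, so the selector gradually concentrates on whoever has been most accurate for that kind of patient.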

While our algorithm is general and can be applied in numerous medical scenarios, we illustrate its functionality and performance by applying it to a real-world breast cancer diagnosis data set. Finally, while the application we consider in this paper is medical diagnosis, our proposed algorithm can be applied in other environments where expertise needs to be discovered.

Competing risks

We can also use machine learning to analyze competing risks; this can be done not only for one particular type of cancer but also other related cancers (for example, breast cancer and ovarian cancer), or different types of diseases (for example, cancer and cardiovascular disease). This lets us better determine screening profiles by adjusting cause-specific predictions, while also managing and prioritizing preventative treatments further down the line.

For this, we have developed a range of methods, some of which are shown below.

Deep Multi-task Gaussian Processes for Survival Analysis with Competing Risks

Ahmed Alaa, Mihaela van der Schaar

Designing optimal treatment plans for patients with comorbidities requires accurate cause-specific mortality prognosis.

Motivated by the recent availability of linked electronic health records, we develop a nonparametric Bayesian model for survival analysis with competing risks, which can be used for jointly assessing a patient’s risk of multiple (competing) adverse outcomes. The model views a patient’s survival times with respect to the competing risks as the outputs of a deep multi-task Gaussian process (DMGP), the inputs to which are the patients’ covariates.

Unlike parametric survival analysis methods based on Cox and Weibull models, our model uses DMGPs to capture complex non-linear interactions between the patients’ covariates and cause-specific survival times, thereby learning flexible patient-specific and cause-specific survival curves, all in a data-driven fashion without explicit parametric assumptions on the hazard rates. We propose a variational inference algorithm that is capable of learning the model parameters from time-to-event data while handling right censoring.

Experiments on synthetic and real data show that our model outperforms the state-of-the-art survival models.

DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks

Changhee Lee, William R. Zame, Jinsung Yoon, Mihaela van der Schaar

Survival analysis (time-to-event analysis) is widely used in economics and finance, engineering, medicine and many other areas. A fundamental problem is to understand the relationship between the covariates and the (distribution of) survival times (times-to-event).

Much of the previous work has approached the problem by viewing the survival time as the first hitting time of a stochastic process, assuming a specific form for the underlying stochastic process, using available data to learn the relationship between the covariates and the parameters of the model, and then deducing the relationship between covariates and the distribution of first hitting times (the risk). However, previous models rely on strong parametric assumptions that are often violated.

This paper proposes a very different approach to survival analysis, DeepHit, that uses a deep neural network to learn the distribution of survival times directly. DeepHit makes no assumptions about the underlying stochastic process and allows for the possibility that the relationship between covariates and risk(s) changes over time. Most importantly, DeepHit smoothly handles competing risks; i.e. settings in which there is more than one possible event of interest.
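Once a model outputs a joint distribution over which event happens first and when, cause-specific risks follow by summation. The sketch below assumes discrete time steps and a hand-written probability table standing in for a network's output for one patient; the function name and the numbers are hypothetical.

```python
def cumulative_incidence(pmf, event, horizon):
    """P(event occurs first, at or before `horizon`) from a joint PMF.

    pmf[k][t] : estimated probability that event k is the first event
                and occurs at discrete time step t.
    """
    return sum(pmf[event][: horizon + 1])

# Hypothetical output for one patient: 2 competing events, 4 time steps.
# All entries together sum to one.
pmf = [[0.05, 0.10, 0.15, 0.10],   # event 0, e.g. cancer-related death
       [0.10, 0.20, 0.20, 0.10]]   # event 1, e.g. cardiovascular death

cif_event0 = cumulative_incidence(pmf, event=0, horizon=1)
cif_event1 = cumulative_incidence(pmf, event=1, horizon=1)
event_free = 1.0 - cif_event0 - cif_event1
```

Because the table is a single joint distribution, the cumulative incidence curves of the competing events can never add up to more than one, which is what makes the cause-specific risks directly comparable.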

Comparisons with previous models on the basis of real and synthetic datasets demonstrate that DeepHit achieves large and statistically significant performance improvements over previous state-of-the-art methods.

Multitask Boosting for Survival Analysis with Competing Risks

Alexis Bellot, Mihaela van der Schaar

The co-occurrence of multiple diseases among the general population is an important problem as those patients have more risk of complications and represent a large share of health care expenditure. Learning to predict time-to-event probabilities for these patients is a challenging problem because the risks of events are correlated (there are competing risks) with often only few patients experiencing individual events of interest, and of those only a fraction are actually observed in the data.

We introduce in this paper a survival model with the flexibility to leverage a common representation of related events that is designed to correct for the strong imbalance in observed outcomes. The procedure is sequential: outcome-specific survival distributions form the components of nonparametric multivariate estimators which we combine into an ensemble in such a way as to ensure accurate predictions on all outcome types simultaneously.

Our algorithm is general and represents the first boosting-like method for time-to-event data with multiple outcomes. We demonstrate the performance of our algorithm on synthetic and real data.

Temporal Quilting for Survival Analysis

Changhee Lee, William Zame, Ahmed Alaa, Mihaela van der Schaar

The importance of survival analysis in many disciplines (especially in medicine) has led to the development of a variety of approaches to modeling the survival function. Models constructed via various approaches offer different strengths and weaknesses in terms of discriminative performance and calibration, but no one model is best across all datasets or even across all time horizons within a single dataset. Because we require both good calibration and good discriminative performance over different time horizons, conventional model selection and ensemble approaches are not applicable.

This paper develops a novel approach that combines the collective intelligence of different underlying survival models to produce a valid survival function that is well-calibrated and offers superior discriminative performance at different time horizons.

Empirical results show that our approach provides significant gains over the benchmarks on a variety of real-world datasets.

Application of a novel machine learning framework for predicting non-metastatic prostate cancer-specific mortality in men using the Surveillance, Epidemiology, and End Results (SEER) database

Changhee Lee, Alexander Light, Ahmed Alaa, David Thurtle, Mihaela van der Schaar, Vincent J Gnanapragasam

Background: Accurate prognostication is crucial in treatment decisions made for men diagnosed with non-metastatic prostate cancer. Current models rely on prespecified variables, which limits their performance. We aimed to investigate a novel machine learning approach to develop an improved prognostic model for predicting 10-year prostate cancer-specific mortality and compare its performance with existing validated models.

Methods: We derived and tested a machine learning-based model using Survival Quilts, an algorithm that automatically selects and tunes ensembles of survival models using clinicopathological variables. Our study involved a US population-based cohort of 171 942 men diagnosed with non-metastatic prostate cancer between Jan 1, 2000, and Dec 31, 2016, from the prospectively maintained Surveillance, Epidemiology, and End Results (SEER) Program. The primary outcome was prediction of 10-year prostate cancer-specific mortality. Model discrimination was assessed using the concordance index (c-index), and calibration was assessed using Brier scores. The Survival Quilts model was compared with nine other prognostic models in clinical use, and decision curve analysis was done.
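For readers unfamiliar with the concordance index used to assess discrimination, the sketch below computes Harrell's c-index for right-censored data: among pairs of patients whose ordering is known, it measures how often the patient with the earlier event was assigned the higher predicted risk. The toy follow-up times and risk scores are hypothetical, and the study's exact evaluation procedure may differ.

```python
def concordance_index(times, events, risks):
    """Harrell's c-index for right-censored data.

    times  : observed follow-up time per patient
    events : 1 if the event was observed, 0 if censored
    risks  : predicted risk score (higher means earlier expected event)
    """
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a pair is anchored on a patient with an observed event
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5  # ties in predicted risk count as half
    return concordant / comparable

# Hypothetical cohort of four patients; the third is censored at time 6.
times = [2, 4, 6, 8]
events = [1, 1, 0, 1]
perfect = concordance_index(times, events, [0.9, 0.7, 0.3, 0.4])
imperfect = concordance_index(times, events, [0.9, 0.2, 0.3, 0.4])
```

A value of 1.0 means every comparable pair is ranked correctly and 0.5 corresponds to random ranking, which gives context for figures such as the 0·829 reported below.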

Findings: 647 151 men with prostate cancer were enrolled into the SEER database, of whom 171 942 were included in this study. Discrimination improved with greater granularity, and multivariable models outperformed tier-based models. The Survival Quilts model showed good discrimination (c-index 0·829, 95% CI 0·820–0·838) for 10-year prostate cancer-specific mortality, which was similar to the top-ranked multivariable models: PREDICT Prostate (0·820, 0·811–0·829) and Memorial Sloan Kettering Cancer Center (MSKCC) nomogram (0·787, 0·776–0·798). All three multivariable models showed good calibration with low Brier scores (Survival Quilts 0·036, 95% CI 0·035–0·037; PREDICT Prostate 0·036, 0·035–0·037; MSKCC 0·037, 0·035–0·039). Of the tier-based systems, the Cancer of the Prostate Risk Assessment model (c-index 0·782, 95% CI 0·771–0·793) and Cambridge Prognostic Groups model (0·779, 0·767–0·791) showed higher discrimination for predicting 10-year prostate cancer-specific mortality. c-indices for models from the National Comprehensive Cancer Care Network, Genitourinary Radiation Oncologists of Canada, American Urological Association, European Association of Urology, and National Institute for Health and Care Excellence ranged from 0·711 (0·701–0·721) to 0·761 (0·750–0·772). Discrimination for the Survival Quilts model was maintained when stratified by age and ethnicity. Decision curve analysis showed an incremental net benefit from the Survival Quilts model compared with the MSKCC and PREDICT Prostate models currently used in practice.

Interpretation: A novel machine learning-based approach produced a prognostic model, Survival Quilts, with discrimination for 10-year prostate cancer-specific mortality similar to the top-ranked prognostic models, using only standard clinicopathological variables. Future integration of additional data will likely improve model performance and accuracy for personalised prognostics.


Truly personalized healthcare (which we refer to as “bespoke medicine”) goes far beyond providing predictions for individual patients: we also need to understand the effect of specific treatments on specific patients at specific times. This is what we call individualized treatment effect inference. It is a substantially more complex undertaking than prediction, and every bit as important, particularly in treating a disease like cancer, since no two patients will have the same cancer pathway.

When deciding on a treatment for a given form of cancer, clinical decisions are often made on the basis of results from randomized controlled trials of treatments involving that cancer. This approach assumes a response to treatment based on the response of the “average patient,” rather than taking into account the health history and specific features of the individual.

Rather than making treatment decisions based on such blanket assumptions, the goal of clinical decision-makers has shifted to determining the optimal treatment course for any given patient at any given time. Methods for doing so in a quantitative fashion based on insights from machine learning are in the formative stages of development, and our lab has built a position of leadership in this area. We have defined the research agenda by outlining and addressing key complexities and challenges, and by laying the theoretical groundwork for model development.

To read more about individualized treatment effect inference, visit our dedicated page on the topic.

GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets

Jinsung Yoon, James Jordon, Mihaela van der Schaar

Estimating individualized treatment effects (ITE) is a challenging task due to the need for an individual’s potential outcomes to be learned from biased data and without having access to the counterfactuals.

We propose a novel method for inferring ITE based on the Generative Adversarial Nets (GANs) framework. Our method, termed Generative Adversarial Nets for inference of Individualized Treatment Effects (GANITE), is motivated by the possibility that we can capture the uncertainty in the counterfactual distributions by attempting to learn them using a GAN. We generate proxies of the counterfactual outcomes using a counterfactual generator, G, and then pass these proxies to an ITE generator, I, in order to train it. By modeling both of these using the GAN framework, we are able to infer based on the factual data, while still accounting for the unseen counterfactuals.

We test our method on three real-world datasets (with both binary and multiple treatments) and show that GANITE outperforms state-of-the-art methods.
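
To make the underlying problem concrete, here is a small simulation of why individualized treatment effects are hard to learn from biased data. This is not GANITE: it is a sketch under assumed, hand-picked numbers in which a single binary covariate (hypothetically, disease severity) drives both the treatment decision and the outcome, and only one of the two potential outcomes is ever observed per patient.

```python
import random

random.seed(0)


def simulate_patient():
    """Return (covariate, treatment, factual outcome) for one synthetic patient."""
    x = random.randint(0, 1)                      # 1 = more severe disease
    y0 = 1.0 + 2.0 * x + random.gauss(0.0, 1.0)   # potential outcome if untreated
    effect = 3.0 if x == 1 else 1.0               # true effect differs by subgroup
    y1 = y0 + effect                              # potential outcome if treated
    p_treat = 0.8 if x == 1 else 0.2              # severe patients are treated more often
    t = 1 if random.random() < p_treat else 0
    y = y1 if t == 1 else y0                      # the counterfactual is never recorded
    return x, t, y


def mean(values):
    return sum(values) / len(values)


patients = [simulate_patient() for _ in range(20000)]

# Naive comparison of treated vs. untreated patients, ignoring the covariate.
naive = (mean([y for _, t, y in patients if t == 1])
         - mean([y for _, t, y in patients if t == 0]))


def stratum_effect(x_value):
    """Treated-minus-control difference among patients sharing one covariate value."""
    treated = [y for x, t, y in patients if x == x_value and t == 1]
    control = [y for x, t, y in patients if x == x_value and t == 0]
    return mean(treated) - mean(control)


effect_by_stratum = {x_value: stratum_effect(x_value) for x_value in (0, 1)}

print("naive estimate:", naive)
print("per-subgroup estimates:", effect_by_stratum)
```

The true average effect in this simulation is 2.0 by construction, but the naive comparison lands well above it because treated patients are disproportionately severe. Comparing like with like within each subgroup recovers the two distinct effects. With many continuous covariates such simple stratification is no longer possible, which is where learned models come in.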

Limits of Estimating Heterogeneous Treatment Effects: Guidelines for Practical Algorithm Design

Ahmed Alaa, Mihaela van der Schaar

Estimating heterogeneous treatment effects from observational data is a central problem in many domains. Because counterfactual data is inaccessible, the problem differs fundamentally from supervised learning, and entails a more complex set of modeling choices. Despite a variety of recently proposed algorithmic solutions, a principled guideline for building estimators of treatment effects using machine learning algorithms is still lacking.

In this paper, we provide such a guideline by characterizing the fundamental limits of estimating heterogeneous treatment effects, and establishing conditions under which these limits can be achieved. Our analysis reveals that the relative importance of the different aspects of observational data varies with the sample size. For instance, we show that selection bias matters only in small-sample regimes, whereas with a large sample size, the way an algorithm models the control and treated outcomes is what bottlenecks its performance. Guided by our analysis, we build a practical algorithm for estimating treatment effects using a non-stationary Gaussian process with doubly-robust hyperparameters.

Using a standard semi-synthetic simulation setup, we show that our algorithm outperforms the state-of-the-art, and that the behavior of existing algorithms conforms with our analysis.

Forecasting Treatment Responses Over Time Using Recurrent Marginal Structural Networks

Bryan Lim, Ahmed Alaa, Mihaela van der Schaar

Electronic health records provide a rich source of data for machine learning methods to learn dynamic treatment responses over time. However, any direct estimation is hampered by the presence of time-dependent confounding, where actions taken are dependent on time-varying variables related to the outcome of interest.

Drawing inspiration from marginal structural models, a class of methods in epidemiology which use propensity weighting to adjust for time-dependent confounders, we introduce the Recurrent Marginal Structural Network – a sequence-to-sequence architecture for forecasting a patient’s expected response to a series of planned treatments.

Using simulations of a state-of-the-art pharmacokinetic-pharmacodynamic (PK-PD) model of tumor growth, we demonstrate the ability of our network to accurately learn unbiased treatment responses from observational data, even under changes in the policy of treatment assignments, and show performance gains over benchmarks.
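
The propensity weighting borrowed from marginal structural models can be sketched in a few lines. The example below computes a stabilized inverse-probability-of-treatment weight for one patient's treatment sequence. It is an illustration only: the function name and the probabilities are hypothetical and hand-set, whereas in practice these probabilities would have to be estimated by a model.

```python
def stabilized_weight(treatments, p_given_treatment_history, p_given_full_history):
    """Stabilized inverse-probability-of-treatment weight for one patient.

    treatments: 0/1 treatment actually received at each time step.
    p_given_treatment_history: probability of being treated at each step given
        past treatments only (numerator model).
    p_given_full_history: probability of being treated at each step given past
        treatments and time-varying covariates (denominator model).
    """
    weight = 1.0
    for a, p_num, p_den in zip(treatments, p_given_treatment_history,
                               p_given_full_history):
        # Use the probability of the action that was actually taken.
        numerator = p_num if a == 1 else 1.0 - p_num
        denominator = p_den if a == 1 else 1.0 - p_den
        weight *= numerator / denominator
    return weight


# Hypothetical patient: treated, not treated, treated over three time steps.
treatments = [1, 0, 1]
p_given_treatment_history = [0.5, 0.5, 0.5]
p_given_full_history = [0.8, 0.4, 0.25]

weight = stabilized_weight(treatments, p_given_treatment_history,
                           p_given_full_history)
print(weight)
```

Patients whose treatment sequence was unlikely given their covariate history receive weights above 1, and those whose sequence was highly predictable receive weights below 1. Reweighting the data in this way is what allows the outcome model to be trained as though treatment had not depended on the time-varying confounders.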

Estimating counterfactual treatment outcomes over time through adversarially balanced representations

Ioana Bica, Ahmed Alaa, James Jordon, Mihaela van der Schaar

Identifying when to give treatments to patients and how to select among multiple treatments over time are important medical problems with few existing solutions.

In this paper, we introduce the Counterfactual Recurrent Network (CRN), a novel sequence-to-sequence model that leverages the increasingly available patient observational data to estimate treatment effects over time and answer such medical questions. To handle the bias from time-varying confounders, covariates affecting the treatment assignment policy in the observational data, CRN uses domain adversarial training to build balancing representations of the patient history. At each timestep, CRN constructs a treatment invariant representation which removes the association between patient history and treatment assignments and thus can be reliably used for making counterfactual predictions.

On a simulated model of tumour growth, with varying degrees of time-dependent confounding, we show how our model achieves lower error in estimating counterfactuals and in choosing the correct treatment and timing of treatment than current state-of-the-art methods.


The lengthy trajectory and complex evolution of cancer over time mean that follow-up care is a particularly important part of the patient pathway. Machine learning is well-positioned to help predict and prevent recurrence and relapse, and to support decision-making around them.

One of our key projects in this area, temporal phenotyping of disease progression, is outlined below.

Outcome-Oriented Deep Temporal Phenotyping of Disease Progression

Changhee Lee, Jem Rashbass, Mihaela van der Schaar

Chronic diseases evolve slowly throughout a patient’s lifetime, creating heterogeneous progression patterns that make clinical outcomes remarkably varied across individual patients. A tool capable of identifying temporal phenotypes based on the patients’ different progression patterns and clinical outcomes would allow clinicians to better forecast disease progression by recognizing a group of similar past patients, and to better design treatment guidelines that are tailored to specific phenotypes.

To build such a tool, we propose a deep learning approach, which we refer to as outcome-oriented deep temporal phenotyping (ODTP), to identify temporal phenotypes of disease progression considering what type of clinical outcomes will occur and when based on the longitudinal observations. More specifically, we model clinical outcomes throughout a patient’s longitudinal observations via time-to-event (TTE) processes whose conditional intensity functions are estimated as non-linear functions using a recurrent neural network. Temporal phenotyping of disease progression is carried out by our novel loss function that is specifically designed to learn discrete latent representations that best characterize the underlying TTE processes. The key insight here is that learning such discrete representations groups progression patterns considering the similarity in expected clinical outcomes, and thus naturally provides outcome-oriented temporal phenotypes.

We demonstrate the power of ODTP by applying it to a real-world heterogeneous cohort of 11,779 stage III breast cancer patients from the UK National Cancer Registration and Analysis Service. The experiments show that ODTP identifies temporal phenotypes that are strongly associated with future clinical outcomes and achieves significant gains on the homogeneity and heterogeneity measures over existing methods. Furthermore, we are able to identify the key driving factors that lead to transitions between phenotypes, which can be translated into actionable information to support better clinical decision-making.

Further resources

The post above has introduced and explained a range of methods our lab has developed to provide actionable, accurate, and interpretable information at various points along the cancer pathway. Some of these have been integrated into a live demonstrator system based on breast cancer, fed by anonymized real-world data. More details on this project are available in the video below, taken from a presentation given by Mihaela van der Schaar and Dr. Jem Rashbass at the Royal College of Physicians in 2019.

This demonstrator (and several related pieces of work) is also introduced in an Impact Story published by The Alan Turing Institute.

For a full list of the van der Schaar Lab’s publications, click here.

Mihaela van der Schaar


Mihaela van der Schaar is the John Humphrey Plummer Professor of Machine Learning, Artificial Intelligence and Medicine at the University of Cambridge, a Fellow at The Alan Turing Institute in London, and a Chancellor’s Professor at UCLA.

Mihaela has received numerous awards, including the Oon Prize on Preventative Medicine from the University of Cambridge (2018), a National Science Foundation CAREER Award (2004), 3 IBM Faculty Awards, the IBM Exploratory Stream Analytics Innovation Award, the Philips Make a Difference Award and several best paper awards, including the IEEE Darlington Award.

In 2019, she was identified by the National Endowment for Science, Technology and the Arts as the most-cited female AI researcher in the UK. She was also elected as a 2019 “Star in Computer Networking and Communications” by N²Women. Her research expertise spans signal and image processing, communication networks, network science, multimedia, game theory, distributed systems, machine learning and AI.

Mihaela’s research focus is on machine learning, AI and operations research for healthcare and medicine.

Nick Maxfield


Nick oversees the van der Schaar Lab’s communications, including media relations, content creation, and maintenance of the lab’s online presence.

Nick studied Japanese (BA Hons.) at the University of Oxford, graduating in 2012. Nick previously worked in HQ communications roles at Toyota (2013-2016) and Nissan (2016-2020).

Given his humanities/languages background and experience in communications, Nick is well-positioned to highlight and explain the real-world impact of research that can often be quite esoteric. Thankfully, he is comfortable asking almost endless questions in order to understand a topic.