The van der Schaar Lab has once again set a group record for representation at NeurIPS—widely considered the world’s largest and most prestigious AI and machine learning research conference—with a total of 14 papers accepted for publication this year.
This result represents the most papers accepted for the lab at any conference to date, and perfectly captures the diverse strengths of its small research team. The papers cover a number of the lab’s key research pillars, such as interpretable machine learning, individualized treatment effect inference, understanding and supporting decision-making (quantitative epistemology), and time series analysis, among others.
Titles and abstracts for all 14 accepted papers are given below.
SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data
We study the problem of inferring heterogeneous treatment effects from time-to-event data. While both the related problems of (i) estimating treatment effects for binary or continuous outcomes and (ii) predicting survival outcomes have been well studied in the recent machine learning literature, their combination — albeit of high practical relevance — has received considerably less attention.
With the ultimate goal of reliably estimating the effects of treatments on instantaneous risk and survival probabilities, we focus on the problem of learning (discrete-time) treatment-specific conditional hazard functions. We find that unique challenges arise in this context due to a variety of covariate shift issues that go beyond a mere combination of well-studied confounding and censoring biases. We theoretically analyse their effects by adapting recent generalization bounds from domain adaptation and treatment effect estimation to our setting and discuss implications for model design. We use the resulting insights to propose a novel deep learning method for treatment-specific hazard estimation based on balancing representations.
We investigate performance across a range of experimental settings and empirically confirm that our method outperforms baselines by addressing covariate shifts from various sources.
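To make the quantities in this abstract concrete, the sketch below shows how a discrete-time hazard function maps to survival probabilities, and how two treatment-specific hazard curves give a difference in survival. It is an illustration only, not the SurvITE model: the function names are hypothetical, the hazards are hand-picked numbers rather than learned estimates, and it assumes the convention S(t) = P(T > t) = product over k <= t of (1 - h(k)).

```python
def survival_curve(hazards):
    """Turn discrete-time hazards h(1..T) into survival probabilities S(1..T).

    Assumed convention: S(t) = prod_{k <= t} (1 - h(k)).
    """
    curve = []
    still_event_free = 1.0
    for h in hazards:
        still_event_free *= 1.0 - h
        curve.append(still_event_free)
    return curve


def survival_effect(hazards_treated, hazards_control):
    """Difference in survival probability (treated minus control) per time step."""
    treated = survival_curve(hazards_treated)
    control = survival_curve(hazards_control)
    return [s1 - s0 for s1, s0 in zip(treated, control)]


if __name__ == "__main__":
    # Hypothetical per-step hazards for one individual under each treatment.
    print(survival_curve([0.5, 0.5]))
    print(survival_effect([0.25, 0.25], [0.5, 0.5]))
```

In a full method, the two hazard sequences would come from a model conditioned on the individual's covariates; the arithmetic from hazards to survival stays the same.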
On Inductive Biases for Heterogeneous Treatment Effect Estimation
We investigate how to exploit structural similarities of an individual’s potential outcomes (POs) under different treatments to obtain better estimates of conditional average treatment effects in finite samples.
Especially when it is unknown whether a treatment has an effect at all, it is natural to hypothesize that the POs are similar — yet, some existing strategies for treatment effect estimation employ regularization schemes that implicitly encourage heterogeneity even when it does not exist and fail to fully make use of shared structure.
In this paper, we investigate and compare three end-to-end learning strategies to overcome this problem — based on regularization, reparametrization and a flexible multi-task architecture — each encoding inductive bias favoring shared behavior across POs.
To build understanding of their relative strengths, we implement all strategies using neural networks and conduct a wide range of semi-synthetic experiments. We observe that all three approaches can lead to substantial improvements upon numerous baselines and gain insight into performance differences across various experimental settings.
Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation
The machine learning (ML) toolbox for estimation of heterogeneous treatment effects from observational data is expanding rapidly, yet many of its algorithms have been evaluated only on a very limited set of semi-synthetic benchmark datasets.
In this paper, we investigate current benchmarking practices for ML-based conditional average treatment effect (CATE) estimators, with special focus on empirical evaluation based on the popular semi-synthetic IHDP benchmark. We identify problems with current practice and highlight that semi-synthetic benchmark datasets, which (unlike real-world benchmarks used elsewhere in ML) do not necessarily reflect properties of real data, can systematically favor some algorithms over others — a fact that is rarely acknowledged but of immense relevance for interpretation of empirical results.
Further, we argue that current evaluation metrics evaluate performance only for a small subset of possible use cases of CATE estimators, and discuss alternative metrics relevant for applications in personalized medicine.
Additionally, we discuss alternatives for current benchmark datasets, and implications of our findings for benchmarking in CATE estimation.
Estimating Multi-cause Treatment Effects via Single-cause Perturbation
Most existing methods for conditional average treatment effect estimation are designed to estimate the effect of a single cause — only one variable can be intervened on at one time.
However, many applications involve simultaneous intervention on multiple variables, which leads to multi-cause treatment effect problems. The multi-cause problem is challenging due to severe data scarcity — we only observe the outcome corresponding to the treatment that was actually given but need to infer a large number of potential outcomes under different combinations of the causes.
In this work, we propose Single-cause Perturbation (SCP), a novel two-step procedure to estimate the multi-cause treatment effect. SCP starts by augmenting the observational dataset with the estimated potential outcomes under single-cause interventions. It then performs covariate adjustment on the augmented dataset to obtain the estimator. SCP is agnostic to the exact choice of algorithm in either step.
We show formally that the procedure is valid under standard assumptions in causal inference. We demonstrate the performance gain of SCP on extensive simulation and real data experiments.
Integrating Expert ODEs into Neural ODEs: Pharmacology and Disease Progression
Modeling a system’s temporal behaviour in reaction to external stimuli is a fundamental problem in many areas. Pure Machine Learning (ML) approaches often fail in the small sample regime and cannot provide actionable insights beyond predictions. A promising modification has been to incorporate expert domain knowledge into ML models.
The application we consider is predicting the progression of disease under medications, where a plethora of domain knowledge is available from pharmacology. Pharmacological models describe the dynamics of carefully-chosen medically meaningful variables in terms of systems of Ordinary Differential Equations (ODEs). However, these models only describe a limited collection of variables, and these variables are often not observable in clinical environments.
To close this gap, we propose the latent hybridisation model (LHM) that integrates a system of expert-designed ODEs with machine-learned Neural ODEs to fully describe the dynamics of the system and to link the expert and latent variables to observable quantities.
We evaluated LHM on synthetic data as well as real-world intensive care data of COVID-19 patients. LHM consistently outperforms previous works, especially when few training samples are available such as at the beginning of the pandemic.
SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes
Most medical observational studies estimate causal treatment effects using electronic health records (EHR), where a patient’s covariates and outcomes are both observed longitudinally. However, previous methods focus only on adjusting for the covariates while neglecting the temporal structure in the outcomes.
To bridge the gap, this paper develops a new method, SyncTwin, that learns a patient-specific time-constant representation from the pre-treatment observations. SyncTwin issues counterfactual predictions for a target patient by constructing a synthetic twin that closely matches the target in representation. The reliability of the estimated treatment effect can be assessed by comparing the observed and synthetic pre-treatment outcomes. Medical experts can interpret the estimate by examining the individuals who contribute most to the synthetic twin.
In the real-data experiment, SyncTwin successfully reproduced the findings of a randomized controlled clinical trial using observational data, which demonstrates its usability on complex real-world EHR data.
Invariant Causal Imitation Learning for Generalizable Policies
Consider learning an imitation policy on the basis of demonstrated behavior from multiple environments, with an eye towards deployment in an unseen environment. Since the observable features from each setting may be different, directly learning individual policies as mappings from features to actions is prone to spurious correlations—and may not generalize well. However, the expert’s policy is often a function of a shared latent structure underlying those observable features that is invariant across settings.
By leveraging data from multiple environments, we propose Invariant Causal Imitation Learning (ICIL), a novel technique in which we learn a feature representation that is invariant across domains, on the basis of which we learn an imitation policy that matches expert behavior. To cope with transition dynamics mismatch, ICIL learns a shared representation of causal features (for all training environments), that is disentangled from the specific representations of noise variables (for each of those environments). Moreover, to ensure that the learned policy matches the observation distribution of the expert’s policy, ICIL estimates the energy of the expert’s observations and uses a regularization term that minimizes the imitator policy’s next state energy.
Experimentally, we compare our method against several benchmarks in control and healthcare tasks and show its effectiveness in learning imitation policies capable of generalizing to unseen environments.
Time-series Generation by Contrastive Imitation
Consider learning a generative model for time-series data. The sequential setting poses a unique challenge: Not only should the generator capture the *conditional* dynamics of (stepwise) transitions, but its open-loop rollouts should also preserve the *joint* distribution of (multi-step) trajectories.
On one hand, autoregressive models trained by MLE allow learning and computing explicit transition distributions, but suffer from compounding error during rollouts. On the other hand, adversarial models based on GAN training alleviate such exposure bias, but transitions are implicit and hard to assess.
In this work, we propose a novel framework marrying the best of both worlds: Motivated by a precise moment-matching objective to mitigate compounding error, we optimize a local (but forward-looking) *transition policy*, where the reinforcement signal is provided by a global (but stepwise-decomposable) *energy model* trained by contrastive estimation. In learning, the two components are trained cooperatively, avoiding the instabilities typical of adversarial objectives. Moreover, while the learned policy serves as the generator for sampling, the learned energy naturally serves as a trajectory-level measure for evaluating sample quality. By expressly training a policy to imitate sequential behavior of time-series features in a dataset, our approach embodies “generation by imitation”.
Theoretically, we demonstrate the correctness of our formulation and consistency of our algorithm. Empirically, we evaluate its ability to generate realistic samples using real-world datasets, and verify that it performs at or above the standard of existing benchmarks.
The Medkit-Learn(ing) Environment: Medical Decision Modelling through Simulation
The goal of understanding decision-making behaviours in clinical environments is of paramount importance if we are to bring the strengths of machine learning to ultimately improve patient outcomes.
Mainstream development of algorithms is often geared towards optimal performance in tasks that do not necessarily translate well into the medical regime—due to several factors including the lack of public availability of realistic data, the intrinsically offline nature of the problem, as well as the complexity and variety of human behaviours.
We therefore present a new benchmarking suite designed specifically for medical sequential decision modelling: the Medkit-Learn(ing) Environment, a publicly available Python package providing simple and easy access to high-fidelity synthetic medical data.
While providing a standardised way to compare algorithms in a realistic medical setting, we employ a generating process that disentangles the policy and environment dynamics to allow for a range of customisations, thus enabling systematic evaluation of algorithms’ robustness against specific challenges prevalent in healthcare.
Closing the loop in medical decision support by understanding clinical decision-making: A case study on organ transplantation
Significant effort has been placed on developing decision support tools to improve patient care. However, drivers of real-world clinical decisions in complex medical scenarios are not yet well-understood, resulting in substantial gaps between these tools and practical applications.
In light of this, we highlight that more attention on understanding clinical decision-making is required both to elucidate current clinical practices and to enable effective human-machine interactions. This is imperative in high-stakes scenarios with scarce available resources. Using organ transplantation as a case study, we formalize the desiderata of methods for understanding clinical decision-making.
We show that most existing machine learning methods are insufficient to meet these requirements and propose iTransplant, a novel data-driven framework to learn the factors affecting decisions on organ offers in an instance-wise fashion directly from clinical data, as a possible solution.
Through experiments on real-world liver transplantation data from OPTN, we demonstrate the use of iTransplant to: (1) discover which criteria are most important to clinicians for organ offer acceptance; (2) identify patient-specific organ preferences of clinicians allowing automatic patient stratification; and (3) explore variations in transplantation practices between different transplant centers. Finally, we emphasize that the insights gained by iTransplant can be used to inform the development of future decision support tools.
Explaining Latent Representations with a Corpus of Examples
Modern machine learning models are complicated. Most of them rely on convoluted latent representations of their input to issue a prediction. To achieve greater transparency than a black-box that connects inputs to predictions, it is necessary to gain a deeper understanding of these latent representations.
To that aim, we propose SimplEx: a user-centred method that provides example-based explanations with reference to a freely selected set of examples, called the corpus. SimplEx uses the corpus to improve the user’s understanding of the latent space with post-hoc explanations answering two questions: (1) Which corpus examples explain the prediction issued for a given test example? (2) What features of these corpus examples are relevant for the model to relate them to the test example? SimplEx provides an answer by reconstructing the test latent representation as a mixture of corpus latent representations.
Further, we propose a novel approach, the integrated Jacobian, that allows SimplEx to make explicit the contribution of each corpus feature in the mixture. Through experiments on tasks ranging from mortality prediction to image classification, we demonstrate that these decompositions are robust and accurate.
With illustrative use cases in medicine, we show that SimplEx empowers the user by highlighting relevant patterns in the corpus that explain model representations. Moreover, we demonstrate how the freedom in choosing the corpus allows the user to have personalized explanations in terms of examples that are meaningful for them.
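The core idea of reconstructing a test latent representation as a mixture of corpus latent representations can be sketched in a few lines. The code below is a toy illustration under stated assumptions, not the SimplEx implementation: the function names are hypothetical, the latent vectors are hand-written lists rather than outputs of a trained model, and the mixture weights are fitted with a Frank-Wolfe iteration with exact line search, which is a swapped-in optimiser chosen for brevity and may differ from the one used in the paper.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def corpus_weights(corpus, target, steps=200):
    """Fit non-negative weights summing to one so that the weighted mix of
    corpus latent vectors approximates the target latent vector.

    Minimises the squared reconstruction error over the probability simplex
    with Frank-Wolfe steps (an assumption of this sketch).
    """
    n = len(corpus)
    weights = [1.0 / n] * n
    recon = [sum(w * c[j] for w, c in zip(weights, corpus))
             for j in range(len(target))]
    for _ in range(steps):
        residual = [r - t for r, t in zip(recon, target)]
        # The gradient for weight i is proportional to residual . corpus[i];
        # move towards the corpus example with the smallest value.
        best = min(range(n), key=lambda i: dot(residual, corpus[i]))
        direction = [c - r for c, r in zip(corpus[best], recon)]
        denom = dot(direction, direction)
        if denom == 0.0:
            break
        # Exact line search for a quadratic objective, clipped to [0, 1].
        step = max(0.0, min(1.0, -dot(residual, direction) / denom))
        if step == 0.0:
            break
        weights = [(1.0 - step) * w for w in weights]
        weights[best] += step
        recon = [r + step * d for r, d in zip(recon, direction)]
    return weights, recon


if __name__ == "__main__":
    corpus_latents = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]]  # hypothetical
    test_latent = [1.0, 0.5]                               # hypothetical
    weights, recon = corpus_weights(corpus_latents, test_latent)
    print(weights, recon)
```

The fitted weights indicate which corpus examples contribute most to the reconstruction; the size of the remaining residual indicates how well the corpus can explain the test example at all.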
Conformal Time-series Forecasting
Current approaches for (multi-horizon) time-series forecasting using recurrent neural networks (RNNs) focus on issuing point estimates, which are insufficient for informing decision-making in critical application domains wherein uncertainty estimates are also required.
Existing methods for uncertainty quantification in RNN-based time-series forecasts are limited as they may require significant alterations to the underlying architecture, may be computationally complex, may be difficult to calibrate, may incur high sample complexity, and may not provide theoretical validity guarantees for the issued uncertainty intervals.
In this work, we extend the inductive conformal prediction framework to the time-series forecasting setup, and propose a lightweight uncertainty estimation procedure to address the above limitations. With minimal exchangeability assumptions, our approach provides uncertainty intervals with theoretical guarantees on frequentist coverage for any multi-horizon forecast predictor and any dataset.
We demonstrate the effectiveness of the conformal forecasting framework by comparing it with existing baselines on a variety of synthetic and real-world datasets.
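As a rough illustration of inductive (split) conformal prediction applied to multi-horizon forecasts, the sketch below computes one interval half-width per forecast horizon from a held-out calibration set. It rests on assumptions that may differ from the paper's exact procedure: absolute error is used as the nonconformity score, the error rate is split evenly across horizons (a Bonferroni-style correction), and the function names are hypothetical. The forecaster itself is treated as a black box whose predictions are already available.

```python
import math


def conformal_widths(cal_true, cal_pred, alpha):
    """Per-horizon interval half-widths from calibration sequences.

    cal_true, cal_pred: lists of equal-length sequences (one value per horizon).
    alpha: target miscoverage for the whole multi-horizon forecast.
    """
    n = len(cal_true)
    horizons = len(cal_true[0])
    # Split the error budget across horizons (assumption of this sketch).
    level = 1.0 - alpha / horizons
    rank = math.ceil((n + 1) * level)
    widths = []
    for h in range(horizons):
        scores = sorted(abs(t[h] - p[h]) for t, p in zip(cal_true, cal_pred))
        # Too few calibration points gives an unbounded interval.
        widths.append(scores[rank - 1] if rank <= n else math.inf)
    return widths


def predict_intervals(forecast, widths):
    """Wrap a point forecast in symmetric intervals, one per horizon."""
    return [(f - w, f + w) for f, w in zip(forecast, widths)]


if __name__ == "__main__":
    # Hypothetical calibration data: 7 sequences, 2 horizons, forecasts of 0.
    cal_pred = [[0.0, 0.0] for _ in range(7)]
    cal_true = [[float(i), 2.0 * i] for i in range(1, 8)]
    widths = conformal_widths(cal_true, cal_pred, alpha=0.5)
    print(predict_intervals([10.0, 20.0], widths))
```

Because the procedure only needs the calibration residuals, it can be wrapped around any point forecaster without changing its architecture, which is the sense in which such a procedure is lightweight.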
MIRACLE: Causally-Aware Imputation via Learning Missing Data Mechanisms
Missing data is an important problem in machine learning practice. Starting from the premise that imputation methods should preserve the causal structure of the data, we develop a regularization scheme that encourages any baseline imputation method to be causally consistent with the underlying data generating mechanism.
Our proposal is a causally-aware imputation algorithm (MIRACLE). MIRACLE iteratively refines the imputation of a baseline by simultaneously modeling the missingness generating mechanism, encouraging imputation to be consistent with the causal structure of the data.
We conduct extensive experiments on synthetic and a variety of publicly available datasets to show that MIRACLE is able to consistently improve imputation over a variety of benchmark methods across all three missingness scenarios: at random, completely at random, and not at random.
DECAF: Generating Fair Synthetic Data Using Causally-Aware Generative Networks
Machine learning models have been criticized for reflecting unfair biases in the training data. Instead of solving for this by introducing fair learning algorithms directly, we focus on generating fair synthetic data, such that any downstream learner is fair. Generating fair synthetic data from unfair data – while remaining truthful to the underlying data-generating process (DGP) – is non-trivial.
In this paper, we introduce DECAF: a GAN-based fair synthetic data generator for tabular data. With DECAF we embed the DGP explicitly as a structural causal model in the input layers of the generator, allowing each variable to be reconstructed conditioned on its causal parents. This procedure enables inference time debiasing, where biased edges can be strategically removed for satisfying user-defined fairness requirements. The DECAF framework is versatile and compatible with several popular definitions of fairness.
In our experiments, we show that DECAF successfully removes undesired bias and – in contrast to existing methods – is capable of generating high-quality synthetic data. Furthermore, we provide theoretical guarantees on the generator’s convergence and the fairness of downstream models.
This year’s NeurIPS conference will run from December 6 through 14. Further details (including the timing of presentations by members of the van der Schaar Lab) will be provided on this page in line with announcements made by the conference’s organizers.
The Conference on Neural Information Processing Systems (NeurIPS) is widely considered the largest and most prestigious conference in AI and machine learning.
The purpose of NeurIPS is to foster the exchange of research on neural information processing systems in their biological, technological, mathematical, and theoretical aspects. The core focus is peer-reviewed novel research which is presented and discussed in the general session, along with invited talks by leaders in their field.
The conference was founded in 1987 and is now a multi-track interdisciplinary annual meeting that includes invited talks, demonstrations, symposia, and oral and poster presentations of refereed papers. Along with the conference is a professional exposition focusing on machine learning in practice, a series of tutorials, and topical workshops that provide a less formal setting for the exchange of ideas.