The van der Schaar Lab has set a new group record for representation at NeurIPS 2020—widely considered the world’s largest AI and machine learning research conference—with a total of 8 papers accepted for presentation.
This is an unprecedented achievement for the lab and demonstrates the varied strengths of its small research team. The papers cover a wide range of topics, including interpretability, uncertainty quantification, causal inference, and imitation learning. Applications in healthcare are similarly broad, ranging from treatment effect estimation to predicting the impact of COVID-19 spread prevention policies.
Titles and abstracts for all 8 selected papers are given below.
When and How to Lift the Lockdown? Global COVID-19 Scenario Analysis and Policy Assessment using Compartmental Gaussian Processes
The coronavirus disease 2019 (COVID-19) global pandemic has led many countries to impose unprecedented lockdown measures in order to slow down the outbreak.
Questions on whether governments have acted promptly enough, and whether lockdown measures can be lifted soon have since been central in public discourse. Data-driven models that predict COVID-19 fatalities under different lockdown policy scenarios are essential for addressing these questions, and for informing governments on future policy directions.
To this end, this paper develops a Bayesian model for predicting the effects of COVID-19 containment policies in a global context — we treat each country as a distinct data point, and exploit variations of policies across countries to learn country-specific policy effects.
Our model utilizes a two-layer Gaussian process (GP) prior — the lower layer uses a compartmental SEIR (Susceptible, Exposed, Infected, Recovered) model as a prior mean function with “country-and-policy-specific” parameters that capture fatality curves under different “counterfactual” policies within each country, whereas the upper layer is shared across all countries, and learns lower-layer SEIR parameters as a function of country features and policy indicators.
Our model combines the solid mechanistic foundations of SEIR models (Bayesian priors) with the flexible data-driven modeling and gradient-based optimization routines of machine learning (Bayesian posteriors) — i.e., the entire model is trained end-to-end via stochastic variational inference.
We compare the projections of our model with other models listed by the Centers for Disease Control and Prevention (CDC), and provide scenario analyses for various lockdown and reopening strategies, highlighting their impact on COVID-19 fatalities.
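For readers unfamiliar with the compartmental SEIR model that the abstract above uses as a prior mean function, the following is a minimal illustrative sketch of a plain deterministic SEIR simulation. It is not the paper's model: there is no Gaussian process, no fatality compartment, and no learned country-and-policy-specific parameters. The function name `simulate_seir`, the parameter values, and the idea of representing a stricter policy as a lower contact rate `beta` are all assumptions made for illustration.

```python
def simulate_seir(beta, sigma, gamma, population, initial_infected, days,
                  steps_per_day=10):
    """Integrate the SEIR equations with forward Euler.

    beta  : contact (transmission) rate per day
    sigma : rate at which exposed individuals become infectious
    gamma : recovery rate
    Returns a list of (S, E, I, R) tuples, one per day, including day 0.
    """
    s = float(population - initial_infected)
    e, i, r = 0.0, float(initial_infected), 0.0
    dt = 1.0 / steps_per_day
    history = [(s, e, i, r)]
    for _ in range(days):
        for _ in range(steps_per_day):
            new_exposed = beta * s * i / population * dt   # S -> E
            new_infectious = sigma * e * dt                # E -> I
            new_recovered = gamma * i * dt                 # I -> R
            s -= new_exposed
            e += new_exposed - new_infectious
            i += new_infectious - new_recovered
            r += new_recovered
        history.append((s, e, i, r))
    return history


if __name__ == "__main__":
    # Hypothetical scenario comparison: a lower beta stands in for a stricter policy.
    for label, beta in [("relaxed", 0.5), ("strict", 0.2)]:
        final = simulate_seir(beta, sigma=0.2, gamma=0.1, population=1_000_000,
                              initial_infected=10, days=200)[-1]
        print(label, "recovered after 200 days:", round(final[3]))
```

In the paper's setting, curves of this kind would serve only as the mean of the lower-layer prior, with the parameters themselves predicted from country features and policy indicators.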
Robust Recursive Partitioning for Heterogeneous Treatment Effects with Uncertainty Quantification
Subgroup analysis of treatment effects plays an important role in applications from medicine to public policy to recommender systems. It allows physicians (for example) to identify groups of patients for whom a given drug or treatment is likely to be effective and groups of patients for whom it is not.
Most of the current methods of subgroup analysis begin with a particular algorithm for estimating individualized treatment effects (ITE) and identify subgroups by maximizing the difference across subgroups of the average treatment effect in each subgroup. These approaches have several weaknesses: they rely on a particular algorithm for estimating ITE, they ignore (in)homogeneity within identified subgroups, and they do not produce good confidence estimates.
This paper develops a new method for subgroup analysis, R2P, that addresses all these weaknesses. R2P uses an arbitrary, exogenously prescribed algorithm for estimating ITE and quantifies the uncertainty of the ITE estimation, using a construction that is more robust than other methods.
Experiments using synthetic and semi-synthetic datasets (based on real data) demonstrate that R2P constructs partitions that are simultaneously more homogeneous within groups and more heterogeneous across groups than the partitions produced by other methods.
Moreover, because R2P can employ any ITE estimator, it also produces much narrower confidence intervals with a prescribed coverage guarantee than other methods.
VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain
Self- and semi-supervised learning frameworks have made significant progress in training machine learning models with limited labeled data in image and language domains. These methods heavily rely on the unique structure in the domain datasets (such as spatial relationships in images or semantic relationships in language). They are not adaptable to general tabular data which does not have the same explicit structure as image and language data.
In this paper, we fill this gap by proposing novel self- and semi-supervised learning frameworks for tabular data, which we refer to collectively as VIME (Value Imputation and Mask Estimation). We create a novel pretext task of estimating mask vectors from corrupted tabular data in addition to the reconstruction pretext task for self-supervised learning.
We also introduce a novel tabular data augmentation method for self- and semi-supervised learning frameworks. In experiments, we evaluate the proposed framework in multiple tabular datasets from various application domains, such as genomics and clinical data.
VIME outperforms the existing baseline methods, exceeding state-of-the-art performance.
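The mask-estimation and reconstruction pretext tasks described above can be illustrated with a small sketch of how the corrupted inputs and their targets might be generated. This is an assumption-laden illustration rather than the authors' implementation: the names `make_mask` and `corrupt`, the masking probability, and the choice to draw replacement values by shuffling each column are hypothetical details not stated in the abstract.

```python
import random


def make_mask(n_rows, n_cols, p_mask, rng):
    """Draw a binary mask; 1 marks an entry selected for corruption."""
    return [[1 if rng.random() < p_mask else 0 for _ in range(n_cols)]
            for _ in range(n_rows)]


def corrupt(rows, mask, rng):
    """Replace masked entries with values drawn from the same column.

    Returns (corrupted_rows, effective_mask). The effective mask marks only
    entries whose value actually changed, since a replacement drawn from the
    column can coincide with the original value.
    """
    n_rows, n_cols = len(rows), len(rows[0])
    # Shuffling a column samples replacements from its empirical distribution.
    shuffled_columns = []
    for j in range(n_cols):
        column = [row[j] for row in rows]
        rng.shuffle(column)
        shuffled_columns.append(column)
    corrupted = [[shuffled_columns[j][i] if mask[i][j] else rows[i][j]
                  for j in range(n_cols)] for i in range(n_rows)]
    effective_mask = [[int(corrupted[i][j] != rows[i][j])
                       for j in range(n_cols)] for i in range(n_rows)]
    return corrupted, effective_mask


if __name__ == "__main__":
    rng = random.Random(0)
    table = [[i, 100 + i, 200 + i] for i in range(5)]
    corrupted, target_mask = corrupt(table, make_mask(5, 3, 0.3, rng), rng)
    print(corrupted)
    print(target_mask)
```

An encoder would then be trained on `corrupted` to predict both `target_mask` (mask estimation) and the original `table` (value imputation), neither of which requires labels.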
CASTLE: Regularization via Auxiliary Causal Graph Discovery
Regularization improves generalization of supervised models to out-of-sample data. Prior works have shown that prediction in the causal direction (effect from cause) results in lower testing error than the anti-causal direction. However, existing regularization methods are agnostic of causality.
We introduce Causal Structure Learning (CASTLE) regularization and propose to regularize a neural network by jointly learning the causal relationships between variables. CASTLE learns the causal directed acyclic graph (DAG) as an adjacency matrix embedded in the neural network’s input layers, thereby facilitating the discovery of optimal predictors.
Furthermore, CASTLE efficiently reconstructs only the features in the causal DAG that have a causal neighbor, whereas reconstruction-based regularizers suboptimally reconstruct all input features. We provide a theoretical generalization bound for our approach and conduct experiments on a plethora of synthetic and real publicly available datasets demonstrating that CASTLE consistently leads to better out-of-sample predictions as compared to other popular benchmark regularizers.
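Learning a DAG as a real-valued adjacency matrix requires some differentiable way to penalize cycles. The abstract above does not say which constraint CASTLE uses, so the sketch below shows one widely used choice as an assumption: a trace-of-matrix-exponential penalty on the element-wise squared weights, which is zero exactly when the weighted graph has no directed cycles. The function names are hypothetical, and the series is truncated at the number of nodes, which is sufficient because any directed cycle has length at most that number.

```python
def matmul(a, b):
    """Multiply two square matrices given as lists of lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def acyclicity_penalty(weights):
    """Sum over k = 1..d of trace((W*W)^k) / k!, with W*W taken element-wise.

    trace(A^k) aggregates the weights of closed walks of length k, so the
    penalty is zero if and only if the weighted graph is acyclic.
    """
    d = len(weights)
    squared = [[w * w for w in row] for row in weights]
    power = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    penalty, factorial = 0.0, 1.0
    for k in range(1, d + 1):
        power = matmul(power, squared)
        factorial *= k
        penalty += sum(power[i][i] for i in range(d)) / factorial
    return penalty


if __name__ == "__main__":
    dag = [[0.0, 2.0, -1.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]]
    cyclic = [[0.0, 1.0], [1.0, 0.0]]
    print(acyclicity_penalty(dag), acyclicity_penalty(cyclic))
```

Added to a prediction loss, a term of this kind pushes the learned input-layer weights towards an acyclic structure while remaining amenable to gradient-based training.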
Learning outside the Black-Box: The pursuit of interpretable models
Machine learning has proved its ability to produce accurate models — but the deployment of these models outside the machine learning community has been hindered by the difficulties of interpreting these models.
This paper proposes an algorithm that produces a continuous global interpretation of any given continuous black-box function. Our algorithm employs a variation of projection pursuit in which the ridge functions are chosen to be Meijer G-functions, rather than the usual polynomial splines. Because Meijer G-functions are differentiable in their parameters, we can “tune” the parameters of the representation by gradient descent; as a consequence, our algorithm is efficient.
Using five familiar data sets from the UCI repository and two familiar machine learning algorithms, we demonstrate that our algorithm produces global interpretations that are both faithful (highly accurate) and parsimonious (involve a small number of terms). Our interpretations permit easy understanding of the relative importance of features and feature interactions.
Our interpretation algorithm represents a leap forward from the previous state of the art.
Estimating the Effects of Continuous-valued Interventions using Generative Adversarial Networks
While much attention has been given to the problem of estimating the effect of discrete interventions from observational data, relatively little work has been done in the setting of continuous-valued interventions, such as treatments associated with a dosage parameter.
In this paper, we tackle this problem by building on a modification of the generative adversarial networks (GANs) framework. Our model, SCIGAN, is flexible and capable of simultaneously estimating counterfactual outcomes for several different continuous interventions.
The key idea is to use a significantly modified GAN model to learn to generate counterfactual outcomes, which can then be used to learn an inference model, using standard supervised methods, capable of estimating these counterfactuals for a new sample. To address the challenges presented by shifting to continuous interventions, we propose a novel architecture for our discriminator – we build a hierarchical discriminator that leverages the structure of the continuous intervention setting. Moreover, we provide theoretical results to support our use of the GAN framework and of the hierarchical discriminator.
In the experiments section, we introduce a new semi-synthetic data simulation for use in the continuous intervention setting and demonstrate improvements over the existing benchmark models.
Gradient Regularized V-Learning for Dynamic Treatment Regimes
Deciding how to optimally treat a patient, including how to select treatments over time among the multiple available treatments, represents one of the most important issues that need to be addressed in medicine today. A dynamic treatment regime (DTR) is a sequence of treatment rules indicating how to individualize treatments for a patient based on the previously assigned treatments and the evolving covariate history. However, DTR evaluation and learning based on offline data remain challenging problems due to the bias introduced by time-varying confounders that affect treatment assignment over time; this may lead to suboptimal treatment rules being used in practice.
In this paper, we introduce Gradient Regularized V-learning (GRV), a novel method for estimating the value function of a DTR. GRV regularizes the underlying outcome and propensity score models with respect to the optimality condition in semiparametric estimation theory. On the basis of this design, we construct estimators that are efficient and have low variance in the finite-sample regime.
Using two synthetic datasets and one real-world medical dataset, we demonstrate that our method is superior to existing baseline methods in estimating value functions and optimizing DTRs, thereby providing significantly improved treatment options over time for patients.
Strictly Batch Imitation Learning by Energy-based Distribution Matching
Consider learning a policy purely on the basis of demonstrated behavior—that is, with no access to reinforcement signals, no knowledge of transition dynamics, and no further interaction with the environment. This strictly batch imitation learning problem arises wherever live experimentation is costly, such as in healthcare.
One solution is simply to retrofit existing algorithms for apprenticeship learning to work in the offline setting. But such an approach bargains heavily on model estimation or off-policy evaluation, and can be indirect and inefficient. We argue that a good solution should be able to explicitly parameterize a policy (i.e. respecting action conditionals), implicitly account for rollout dynamics (i.e. respecting state marginals), and—crucially—operate in an entirely offline fashion.
To meet this challenge, we propose a novel technique based on energy-based distribution matching (EDM): by identifying parameterizations of the (discriminative) model of a policy with the (generative) energy function for state distributions, EDM provides a simple and effective solution that equivalently minimizes a divergence between the occupancy measures of the demonstrator and the imitator.
Through experiments with application to control tasks and healthcare settings, we illustrate consistent performance gains over existing algorithms for strictly batch imitation learning.
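One way to read the idea of identifying a policy's parameterization with an energy function over states is the joint energy-based construction, in which a single vector of per-action logits yields both quantities. The sketch below illustrates that reading only; it is an assumption about the mechanism rather than the paper's implementation, and the function names are hypothetical. The policy is the softmax of the logits, while the state energy is the negative log-sum-exp of the same logits.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def policy_from_logits(logits):
    """Action probabilities pi(a | s): softmax over the per-action logits."""
    lse = log_sum_exp(logits)
    return [math.exp(v - lse) for v in logits]


def state_energy(logits):
    """Energy E(s) of a state; lower energy means higher unnormalized density."""
    return -log_sum_exp(logits)


if __name__ == "__main__":
    logits = [1.0, 2.0, 0.5]          # hypothetical network output for one state
    shifted = [v + 3.0 for v in logits]
    print(policy_from_logits(logits), state_energy(logits))
    print(policy_from_logits(shifted), state_energy(shifted))
```

Adding a constant to every logit leaves the policy unchanged but lowers the state energy, which is the spare degree of freedom that lets one network describe both the action conditionals and a model of the state distribution.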
A full list of the van der Schaar Lab’s publications is available on the lab’s website.