Advancing best-in-class reinforcement learning, symbolic regression, and representation learning for tabular data, and transforming synthetic data, clinical trials, treatment effect estimation, and missing data imputation
We are very proud to be consistently represented at the world’s largest and most prestigious AI and machine learning conferences, where we present cutting-edge research and impactful papers and take part in fruitful workshops.
The 2023 conference season starts with Artificial Intelligence and Statistics (AISTATS), running from 25 – 27 April, one of the most prominent annual gatherings of researchers at the intersection of artificial intelligence, machine learning, statistics, and related areas. This year, the van der Schaar Lab will present 7 papers, the largest number of accepted papers we have had at AISTATS thus far. This shows the ongoing strong efforts of our team to produce increasingly impactful and creative work.
Following close behind will be the Eleventh International Conference on Learning Representations (ICLR), running from 1 – 5 May. ICLR is globally renowned for presenting and publishing cutting-edge research on all aspects of deep learning used in the fields of artificial intelligence, statistics, and data science. The van der Schaar Lab will be part of the event with 4 papers.
Collectively, these papers touch on some of the most important areas within the lab’s extensive research agenda, including reinforcement learning, symbolic regression and representation learning for tabular data, synthetic data, clinical trials, treatment effect estimation, and missing data imputation.
When a decision-maker should break a commitment
Clinical trials are costly, time-intensive necessities in the development of new drugs. Committing to the process carries the inherent risk of failure and loss of investment. Our team asks the question: When should a decision-maker break a commitment that is likely to fail – either to make an alternative commitment or to make no further commitments at all? In this ground-breaking paper, we formulate this question as a novel type of optimal stopping/switching problem called the optimal commitment problem (OCP), and propose a practical algorithm for solving it.
Find the paper here.
Improving symbolic regression
In this work, we address the complex problem of discovering concise closed-form mathematical equations from data by leveraging pre-trained deep generative models to capture the intrinsic regularities of equations. This novel approach fundamentally changes symbolic regression as it unifies several prominent approaches and offers a new perspective to justify and improve on the previous ad hoc designs. We lay the groundwork for superior recovery rates of true equations while also being computationally more efficient at inference time than previous solutions.
Find the paper here.
Realistic synthetic tabular data by exploiting relational structure
From computer vision to natural language processing, deep generative models have shown notable success in learning highly complex and non-linear representations to generate realistic synthetic data. So far, the challenges characterising tabular data, such as heterogeneous relationships, a limited number of samples, and difficulties in incorporating prior knowledge, have made the generation of suitable synthetic data problematic. The van der Schaar Lab is the first to solve this issue by introducing GOGGLE, an end-to-end message passing scheme that jointly learns the relational structure and functional relationships as the basis of generating synthetic samples.
Find the paper here.
Shining the spotlight on competing risks in inferring heterogeneous treatment effects
Although competing risks are of great practical relevance, they have received very little attention in the study of treatment effect estimation. We are the first to theoretically analyse and empirically illustrate when and how competing risks play a role in using generic machine learning prediction models for the estimation of heterogeneous treatment effects. In our investigation, we find that competing risks can act as an additional source of covariate shift, refocussing the lens through which treatment effect estimates should be viewed.
Find the paper here.
Breaking new ground: Neural Laplace Control
Many real-world offline reinforcement learning problems involve continuous-time environments with delays. Observations are made at irregular time intervals and actions, in turn, take effect with unknown delays. Existing offline reinforcement learning algorithms have been shown to be effective with irregularly observed states in time or with known delays. However, we are the first to attempt to solve the challenge of environments involving both irregular observations in time and unknown delays. We succeed with our proposed solution: Neural Laplace Control, a continuous-time model-based offline reinforcement learning method that combines a Neural Laplace dynamics model with a model predictive control planner. In doing so, our model achieves near-expert policy performance.
Find the paper here.
You can find the full list of papers for AISTATS and ICLR below.
AISTATS: 25 – 27 April 2023
Neural Laplace Control for Continuous-time Delayed Systems
Many real-world offline reinforcement learning (RL) problems involve continuous-time environments with delays. Such environments are characterised by two distinctive features: firstly, the state x(t) is observed at irregular time intervals, and secondly, the current action a(t) only affects the future state x(t + g) with an unknown delay g > 0.
A prime example of such an environment is satellite control, where the communication link between earth and a satellite causes irregular observations and delays. Existing offline RL algorithms have achieved success in environments with irregularly observed states in time or known delays. However, environments involving both irregular observations in time and unknown delays remain an open and challenging problem.
To this end, we propose Neural Laplace Control, a continuous-time model-based offline RL method that combines a Neural Laplace dynamics model with a model predictive control (MPC) planner, and is able to learn from an offline dataset sampled with irregular time intervals from an environment that has an inherent unknown constant delay. We show experimentally on continuous-time delayed environments that it is able to achieve near-expert policy performance.
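To make the setting concrete, the following minimal Python sketch simulates a toy environment with both features: irregular observation times and a constant action delay g. The dynamics dx/dt = a(t - g), the `simulate` function, and the range of observation gaps are assumptions chosen purely for illustration. The sketch does not implement Neural Laplace Control itself.

```python
import random

def simulate(policy, delay, horizon=5.0, dt=0.01, seed=0):
    """Euler-integrate dx/dt = a(t - delay) and log x at irregular times."""
    rng = random.Random(seed)
    n_steps = int(round(horizon / dt))
    delay_steps = int(round(delay / dt))
    actions = []        # action issued at every integration step
    observations = []   # (time, state) pairs recorded at irregular intervals
    x = 0.0
    next_obs = 0.0
    for k in range(n_steps):
        t = k * dt
        if t >= next_obs:
            observations.append((t, x))
            next_obs = t + rng.uniform(0.05, 0.5)  # irregular gap (hypothetical range)
        actions.append(policy(t))
        # The action taking effect now was issued delay_steps ago;
        # before the first delayed action arrives, nothing acts on the state.
        effective = actions[k - delay_steps] if k >= delay_steps else 0.0
        x += dt * effective
    return observations

# Constant action of 1.0 with a one-second delay: the state stays at zero
# until the first action takes effect, then grows linearly.
obs = simulate(lambda t: 1.0, delay=1.0)
```

An offline dataset in this setting would consist of such irregularly spaced (time, state) pairs together with the logged actions, with the delay itself left unobserved.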
Understanding the Impact of Competing Risks on Heterogeneous Treatment Effect Estimation from Time-to-Event Data
We study the problem of inferring heterogeneous treatment effects (HTEs) from time-to-event data in the presence of competing risks. Despite its great practical relevance, this problem has received little attention compared to its counterparts studying treatment effect estimation without time-to-event data or competing risks.
We take an outcome modeling approach to estimating HTEs, and consider how and when existing prediction models for time-to-event data can be used as plug-in estimators for potential outcome predictions. We then investigate whether competing risks present new challenges for HTE estimation in addition to the standard confounding problem. We find that, because there are multiple definitions of causal effects in this setting (namely total, direct and separable effects), competing risks can act as an additional source of covariate shift depending on the desired treatment effect interpretation and associated estimand.
We theoretically analyze and empirically illustrate when and how these challenges play a role when using generic machine learning prediction models for the estimation of HTEs.
Improving adaptive conformal prediction using self-supervised learning
Conformal prediction is a powerful distribution-free tool for uncertainty quantification, establishing valid prediction intervals with finite-sample guarantees. To produce valid intervals which are also adaptive to the difficulty of each instance, a common approach is to compute normalized nonconformity scores on a separate calibration set.
Self-supervised learning has been effectively utilized in many domains to learn general representations for downstream predictors. However, the use of self-supervision beyond model pre-training and representation learning has been largely unexplored. In this work, we investigate how self-supervised pretext tasks can improve the quality of the conformal regressors, specifically by improving the adaptability of conformal intervals. We train an auxiliary model with a self-supervised pretext task on top of an existing predictive model and use the self-supervised error as an additional feature to estimate nonconformity scores.
We empirically demonstrate the benefit of the additional information using both synthetic and real data on the efficiency (width), deficit and excess of conformal prediction intervals.
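The idea of normalized nonconformity scores can be sketched in a few lines of Python. This is a generic split conformal regressor under stated assumptions, not the method of the paper: `predict` and `difficulty` are hypothetical stand-ins for a trained model and for any positive estimate of instance difficulty, and the toy data is invented. In the setting described above, the self-supervised error would be one input to such a difficulty estimate.

```python
import math
import random

def conformal_quantile(scores, alpha):
    """Finite-sample-corrected (1 - alpha) quantile of the calibration scores."""
    n = len(scores)
    k = math.ceil((n + 1) * (1 - alpha))  # rank of the order statistic to use
    if k > n:
        return float("inf")  # too few calibration points for this alpha
    return sorted(scores)[k - 1]

def normalized_interval(predict, difficulty, calib, x_new, alpha=0.1):
    """Interval predict(x) +/- q * difficulty(x), with q fitted on calibration data."""
    scores = [abs(y - predict(x)) / difficulty(x) for x, y in calib]
    q = conformal_quantile(scores, alpha)
    half_width = q * difficulty(x_new)
    center = predict(x_new)
    return center - half_width, center + half_width

def predict(x):
    """Hypothetical point predictor."""
    return 2.0 * x

def difficulty(x):
    """Hypothetical difficulty estimate; must be strictly positive."""
    return 0.1 + abs(x)

# Toy calibration set whose noise level grows with |x|.
rng = random.Random(0)
calib = []
for _ in range(500):
    x = rng.uniform(-3.0, 3.0)
    calib.append((x, 2.0 * x + rng.gauss(0.0, 0.1 + abs(x))))

easy = normalized_interval(predict, difficulty, calib, 0.0)
hard = normalized_interval(predict, difficulty, calib, 3.0)
```

Because the same quantile is rescaled by the difficulty estimate, the interval for the hard instance is wider than the interval for the easy one. A better difficulty estimate therefore translates directly into more adaptive intervals.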
To Impute or not to Impute? Missing Data in Treatment Effect Estimation
Missing data is a systemic problem in practical scenarios that causes noise and bias when estimating treatment effects. This makes treatment effect estimation from data with missingness a particularly tricky endeavour. A key reason for this is that standard assumptions on missingness are rendered insufficient due to the presence of an additional variable, treatment, besides the input (e.g. an individual) and the label (e.g. an outcome).
The treatment variable introduces additional complexity with respect to why some variables are missing, which has not been fully explored by previous work. In our work we introduce mixed confounded missingness (MCM), a new missingness mechanism where some missingness determines treatment selection and other missingness is determined by treatment selection. Given MCM, we show that naively imputing all data leads to poorly performing treatment effect models, as the act of imputation effectively removes information necessary to provide unbiased estimates. However, no imputation at all also leads to biased estimates, as missingness determined by treatment introduces bias in covariates.
Our solution is selective imputation, where we use insights from MCM to inform precisely which variables should be imputed and which should not. We empirically demonstrate how various learners benefit from selective imputation compared to other solutions for missing data. We highlight that our experiments encompass both average treatment effects and conditional average treatment effects.
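As a rough illustration of what selective imputation could look like in practice, the Python sketch below imputes only a chosen subset of columns and leaves the remaining missing values in place. It is one reading of the description above, under explicit assumptions: the column names are hypothetical, mean imputation is a placeholder for any imputation method, and the choice of which columns to impute (here, the column whose missingness is assumed to be determined by treatment) has to come from domain knowledge about the missingness mechanism.

```python
def selective_impute(rows, impute_cols):
    """Mean-impute only the columns in impute_cols; leave None elsewhere."""
    means = {}
    for col in impute_cols:
        observed = [r[col] for r in rows if r[col] is not None]
        means[col] = sum(observed) / len(observed)
    return [
        {c: (means[c] if v is None and c in means else v) for c, v in r.items()}
        for r in rows
    ]

# Hypothetical records: "lab_test" is assumed to be missing because of the
# treatment received, while missingness of "biomarker" is assumed to have
# influenced treatment selection and is therefore kept as it is.
rows = [
    {"age": 50, "biomarker": None, "lab_test": 1.0},
    {"age": 60, "biomarker": 0.4, "lab_test": None},
    {"age": 70, "biomarker": 0.8, "lab_test": 3.0},
]
imputed = selective_impute(rows, impute_cols=["lab_test"])
```

A downstream treatment effect learner would then need to handle the remaining missing values explicitly, for example through missingness indicators.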
Membership Inference Attacks against Synthetic Data through Overfitting Detection
Data is the foundation of most science. Unfortunately, sharing data can be obstructed by the risk of violating data privacy, impeding research in fields like healthcare. Synthetic data is a potential solution. It aims to generate data that has the same distribution as the original data, but that does not disclose information about individuals.
Membership Inference Attacks (MIAs) are a common privacy attack, in which the attacker attempts to determine whether a particular real sample was used for training of the model. Previous works that propose MIAs against generative models either display low performance – giving the false impression that data is highly private – or need to assume access to internal generative model parameters – a relatively low-risk scenario, as the data publisher often only releases synthetic data, not the model. In this work we argue for a realistic MIA setting that assumes the attacker has some knowledge of the underlying data distribution. We propose DOMIAS, a density-based MIA model that aims to infer membership by targeting local overfitting of the generative model.
Experimentally we show that DOMIAS is significantly more successful at MIA than previous work, especially at attacking uncommon samples. The latter is disconcerting since these samples may correspond to underrepresented groups. We also demonstrate how DOMIAS’ MIA performance score provides an interpretable metric for privacy, giving data publishers a new tool for achieving the desired privacy-utility trade-off in their synthetic data.
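The intuition of targeting local overfitting can be sketched with a simple density comparison in Python. The ratio of synthetic-data density to reference-data density used below, the one-dimensional Gaussian kernel density estimator, the bandwidth, and the toy data are all assumptions of this sketch and are not taken from the text above. It is not the DOMIAS implementation.

```python
import math

def kde(samples, bandwidth):
    """Return a 1-D Gaussian kernel density estimate built from samples."""
    norm = len(samples) * bandwidth * math.sqrt(2.0 * math.pi)
    def density(x):
        return sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in samples) / norm
    return density

def membership_score(x, synthetic, reference, bandwidth=0.5):
    """Synthetic density divided by reference density at x (higher is more suspicious)."""
    p_syn = kde(synthetic, bandwidth)(x)
    p_ref = kde(reference, bandwidth)(x)
    return p_syn / (p_ref + 1e-12)

# Reference data: what the attacker knows about the underlying distribution.
reference = [i / 10 for i in range(-30, 31)]
# Synthetic data: similar overall, but with extra mass near 2.5, as if the
# generator had overfitted to a training record located there.
synthetic = reference + [2.5] * 10

score_rare = membership_score(2.5, synthetic, reference)
score_common = membership_score(0.0, synthetic, reference)
```

Points where the synthetic data is denser than the reference data suggests it should be receive a score above one and are flagged as likely training members, which is why samples in sparse regions are particularly exposed.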
T-Phenotype: Discovering Phenotypes of Predictive Temporal Patterns in Disease Progression
Clustering time-series data in healthcare is crucial for clinical phenotyping to understand patients’ disease progression patterns and to design treatment guidelines tailored to homogeneous patient subgroups. While rich temporal dynamics enable the discovery of potential clusters beyond static correlations, two major challenges remain outstanding: i) discovery of predictive patterns from many potential temporal correlations in the multi-variate time-series data and ii) association of individual temporal patterns to the target label distribution that best characterizes the underlying clinical progression.
To address such challenges, we develop a novel temporal clustering method, T-Phenotype, to discover phenotypes of predictive temporal patterns from labeled time-series data. We introduce an efficient representation learning approach in the frequency domain that can encode variable-length, irregularly-sampled time-series into a unified representation space, which is then applied to identify various temporal patterns that potentially contribute to the target label using a new notion of path-based similarity.
Through experiments on synthetic and real-world datasets, we show that T-Phenotype achieves the best phenotype discovery performance over all the evaluated baselines. We further demonstrate the utility of T-Phenotype by uncovering clinically meaningful patient subgroups characterized by unique temporal patterns.
SurvivalGAN: Generating time-to-event Data for Survival Analysis
Synthetic data is becoming an increasingly promising technology for research: successful application can improve privacy, fairness and data democratization. While there are many methods for generating synthetic tabular data, the task remains non-trivial and unexplored for specific scenarios.
One such scenario is survival data. Here the key difficulty is censoring, where we don’t know the time of event or if one even occurred. Imbalance in censoring and time horizons causes generative models to experience three new failure modes specific to survival analysis: generating too few at-risk members, generating too many at-risk members, and censoring too early. We formalize these failure modes and provide three new generative metrics to quantify them. Following this, we propose SurvivalGAN, a generative model that handles survival data firstly by addressing the imbalance in the censoring and time horizons, and secondly by using a dedicated mechanism for approximating time-to-event/censoring.
We evaluate this method via extensive experiments on medical datasets. SurvivalGAN outperforms multiple baselines at generating survival data, and in particular addresses the failure modes as measured by the new metrics, improving downstream performance of survival models.
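For readers less familiar with survival data, the short Python sketch below shows what right-censoring does to a dataset and how the number of at-risk members at a given time horizon can be counted. The helper names `censor` and `at_risk` and the toy numbers are hypothetical. The sketch illustrates the data format only and is unrelated to the SurvivalGAN architecture or its metrics.

```python
def censor(event_times, censor_times):
    """Right-censor: keep the earlier of event and censoring time, plus an event flag."""
    return [(min(t, c), t <= c) for t, c in zip(event_times, censor_times)]

def at_risk(records, horizon):
    """Number of subjects still under observation at the given time horizon."""
    return sum(1 for time, _ in records if time >= horizon)

# Three hypothetical subjects: the second one leaves the study at time 4.0,
# before the event at time 5.0 could be observed.
records = censor(event_times=[2.0, 5.0, 7.0], censor_times=[3.0, 4.0, 10.0])
n_at_risk = at_risk(records, horizon=4.0)
```

A generative model that places censoring times too early, or that distorts the balance between censored and uncensored subjects, changes these at-risk counts and therefore the survival curves estimated from the synthetic data.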
ICLR: 1 – 5 May 2023
In many scenarios, decision-makers must commit to long-term actions until their resolution before receiving the payoff of said actions, and usually, staying committed to such actions incurs continual costs. For instance, in healthcare, a newly discovered treatment cannot be marketed to patients until a clinical trial is conducted, which both requires time and is also costly.
Of course in such scenarios, not all commitments eventually pay off. For instance, a clinical trial might end up failing to show efficacy. Given the time pressure created by the continual cost of keeping a commitment, we aim to answer: When should a decision-maker break a commitment that is likely to fail—either to make an alternative commitment or to make no further commitments at all?
First, we formulate this question as a new type of optimal stopping/switching problem called the optimal commitment problem (OCP). Then, we theoretically analyse OCP, and based on the insight we gain, propose a practical algorithm for solving it. Finally, we empirically evaluate the performance of our algorithm in running clinical trials with subpopulation selection.
Symbolic regression (SR) aims to discover concise closed-form mathematical equations from data, a task fundamental to scientific discovery. However, the problem is highly challenging because closed-form equations lie in a complex combinatorial search space.
Existing methods, ranging from heuristic search to reinforcement learning, fail to scale with the number of input variables. We make the observation that closed-form equations often have structural characteristics and invariances (e.g., the commutative law) that could be further exploited to build more effective symbolic regression solutions. Motivated by this observation, our key contribution is to leverage pre-trained deep generative models to capture the intrinsic regularities of equations, thereby providing a solid foundation for subsequent optimization steps.
We show that our novel formalism unifies several prominent approaches of symbolic regression and offers a new perspective to justify and improve on the previous ad hoc designs, such as the usage of cross-entropy loss during pre-training. Specifically, we propose an instantiation of our framework, Deep Generative Symbolic Regression (DGSR). In our experiments, we show that DGSR achieves a higher recovery rate of true equations in the setting of a larger number of input variables, and it is more computationally efficient at inference time than state-of-the-art RL symbolic regression solutions.
TANGOS: Regularizing Tabular Neural Networks through Gradient Orthogonalization and Specialization
Despite their success with unstructured data, deep neural networks are not yet a panacea for structured tabular data. In the tabular domain, their efficiency crucially relies on various forms of regularization to prevent overfitting and provide strong generalization performance.
Existing regularization techniques include broad modelling decisions such as choice of architecture, loss functions, and optimization methods. In this work, we introduce Tabular Neural Gradient Orthogonalization and Specialization (TANGOS), a novel framework for regularization in the tabular setting built on latent unit attributions. The gradient attribution of an activation with respect to a given input feature suggests how the neuron attends to that feature, and is often employed to interpret the predictions of deep networks. In TANGOS, we take a different approach and incorporate neuron attributions directly into training to encourage orthogonalization and specialization of latent attributions in a fully-connected network. Our regularizer encourages neurons to focus on sparse, non-overlapping input features and results in a set of diverse and specialized latent units.
In the tabular domain, we demonstrate that our approach can lead to improved out-of-sample generalization performance, outperforming other popular regularization methods. We provide insight into why our regularizer is effective and demonstrate that TANGOS can be applied jointly with existing methods to achieve even greater generalization performance.
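The two ingredients of the regularizer, specialization and orthogonalization of latent attributions, can be illustrated with a small Python sketch. It assumes a single tanh hidden layer so that gradients can be written in closed form, and it uses the mean absolute attribution and the mean absolute cosine similarity as penalty terms. These particular formulas and function names are assumptions made for illustration and may differ from the definitions in the paper.

```python
import math

def attributions(W, b, x):
    """Gradient of each tanh hidden unit with respect to each input feature."""
    rows = []
    for w_j, b_j in zip(W, b):
        h = math.tanh(sum(w * xi for w, xi in zip(w_j, x)) + b_j)
        rows.append([(1.0 - h * h) * w for w in w_j])  # d tanh(z)/dx_i = (1 - h^2) * w_i
    return rows

def specialization_penalty(attr):
    """Mean absolute attribution: small when each unit relies on few features."""
    return sum(abs(a) for row in attr for a in row) / (len(attr) * len(attr[0]))

def orthogonalization_penalty(attr):
    """Mean absolute cosine similarity between attribution vectors of unit pairs."""
    def cos(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
        return dot / (norm + 1e-12)
    pairs = [(i, j) for i in range(len(attr)) for j in range(i + 1, len(attr))]
    return sum(abs(cos(attr[i], attr[j])) for i, j in pairs) / len(pairs)

x = [0.3, -0.2]
bias = [0.0, 0.0]
disjoint = attributions([[1.0, 0.0], [0.0, 1.0]], bias, x)  # each unit uses one feature
overlap = attributions([[1.0, 1.0], [1.0, 1.0]], bias, x)   # both units use both features
```

In training, a weighted sum of the two penalties would be added to the usual prediction loss. In this toy comparison, the weights with disjoint feature usage incur lower penalties than the weights where both units attend to both features.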
GOGGLE: Generative Modelling for Tabular Data by Learning Relational Structure
Deep generative models learn highly complex and non-linear representations to generate realistic synthetic data. While they have achieved notable success in computer vision and natural language processing, similar advances have been less demonstrable in the tabular domain.
This is partially because generative modelling of tabular data entails a particular set of challenges, including heterogeneous relationships, a limited number of samples, and difficulties in incorporating prior knowledge. Additionally, unlike their counterparts in the image and sequence domains, deep generative models for tabular data almost exclusively employ fully-connected layers, which encode weak inductive biases about relationships between inputs. Real-world data generating processes can often be represented using relational structures, which encode sparse, heterogeneous relationships between variables.
In this work, we learn and exploit relational structure underlying tabular data (where typical dimensionality d < 100) to better model variable dependence, and as a natural means to introduce regularization on relationships and include prior knowledge. Specifically, we introduce GOGGLE, an end-to-end message passing scheme that jointly learns the relational structure and corresponding functional relationships as the basis of generating synthetic samples. Using real-world datasets, we provide empirical evidence that the proposed method is effective in generating realistic synthetic data and exploiting domain knowledge for downstream tasks.