van der Schaar Lab

Causal deep learning


On this page we introduce “causal deep learning”: our vision of how causality can improve deep learning, and vice versa. We discuss shortcomings in both fields and how each may be addressed with inspiration from the other.

We first explain exactly what we mean by causal deep learning, and then give some examples of how our lab has innovated in this area so far.

We believe that causal deep learning is a major new area of research (hence its place as a research pillar of our lab), and what we describe here is only the beginning of this important area of machine learning. Not only are we excited to continue working on it, but we are also keen to motivate researchers outside our lab to take notice and join us in this exciting topic.

This page is one of several introductions to areas that we see as “research pillars” for our lab. It is a living document, and the content here will evolve as we continue to reach out to the machine learning and healthcare communities, building a shared vision for the future of healthcare.

Our primary means of building this shared vision is through two groups of online engagement sessions: Inspiration Exchange (for machine learning students) and Revolutionizing Healthcare (for the healthcare community). If you would like to get involved, please visit the page below.

This page is authored and maintained by Jeroen Berrevoets, Zhaozhi Qian and Mihaela van der Schaar.


Introducing causally inspired deep learning

Causality is the study of cause and effect. Philosophers and scientists have long worked to formalize, identify and quantify causality in nature, dating back at least to the 18th-century philosopher David Hume. In fact, the ability to perform causal reasoning has been recognized as a hallmark of human intelligence.

Our brains store an incredible amount of causal knowledge which, supplemented by data, we could harness to answer some of the most pressing questions of our time. More ambitiously, once we really understand the logic behind causal reasoning, we could emulate it on modern computers and create an “artificial scientist”.

Pearl & Mackenzie 2018

Although there exist several different frameworks for causal reasoning, as well as some ongoing philosophical debate, the most widely adopted notion of causality in computer science is the structural causal model (SCM), whose roots go back to the early 20th century and which is now championed by Judea Pearl. An SCM is composed of three components: (1) a causal directed acyclic graph (DAG) that qualitatively describes the causal relationships between the variables (both observed and unobserved), i.e. if A causes B, then A → B; (2) a set of structural equations that quantitatively specify the generative process, each variable being a function of its parents in the DAG and of an exogenous noise term; and (3) a distribution over the exogenous noise terms, which, through the structural equations, induces the distribution over all variables in the SCM.
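To make the three components concrete, here is a minimal sketch of a toy SCM. The graph, the coefficients and the `sample_scm` helper are illustrative assumptions, not taken from any of our papers: the DAG is Z → X, Z → Y, X → Y, each structural equation adds independent standard-normal noise, and an intervention do(X = x) is implemented by replacing X's equation with a constant.

```python
import numpy as np

# Toy SCM (illustrative assumption):
#   (1) DAG:        Z -> X, Z -> Y, X -> Y
#   (2) equations:  Z = U_Z,  X = Z + U_X,  Y = 2*X + 3*Z + U_Y
#   (3) noise:      U_Z, U_X, U_Y independent N(0, 1)
TOPOLOGICAL_ORDER = ["Z", "X", "Y"]
EQUATIONS = {
    "Z": lambda v, u: u["Z"],
    "X": lambda v, u: v["Z"] + u["X"],
    "Y": lambda v, u: 2.0 * v["X"] + 3.0 * v["Z"] + u["Y"],
}


def sample_scm(n, rng, do=None):
    """Draw n samples; `do` maps variable names to values fixed by intervention."""
    do = do or {}
    noise = {name: rng.standard_normal(n) for name in TOPOLOGICAL_ORDER}
    values = {}
    for name in TOPOLOGICAL_ORDER:  # parents are always computed before children
        if name in do:
            # Intervention: cut incoming edges by replacing the equation with a constant.
            values[name] = np.full(n, float(do[name]))
        else:
            values[name] = EQUATIONS[name](values, noise)
    return values


rng = np.random.default_rng(0)
observational = sample_scm(200_000, rng)
interventional = sample_scm(200_000, rng, do={"X": 1.0})
print(interventional["Y"].mean())  # close to E[Y | do(X=1)] = 2
```

Only X's equation changes under the intervention; the noise distribution and the equations for Z and Y are untouched.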

The SCM framework allows researchers to specify causal assumptions and derive theoretical guarantees. A common assumption is that the DAG is correctly specified. Under this assumption, it has been proved that observational data can, in many cases, be used to answer “what if” questions about interventional or counterfactual outcomes, which are never observed directly. Yet practitioners usually do not have access to the ground-truth DAG. The field of causal discovery attempts to recover DAGs from data. However, a well-known impossibility result shows that it is generally impossible to uniquely identify the true DAG from observational data (in general, only its Markov equivalence class can be recovered) unless one is willing to accept strong assumptions, including parametric assumptions on the structural equations and their distributions, and assumptions on the completeness of the observed variables.
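Both halves of this argument can be illustrated on toy linear-Gaussian data (a sketch; the data-generating process, coefficients and helper names are assumptions for illustration). If the DAG Z → X, Z → Y, X → Y is known to be correct, the backdoor adjustment (regressing Y on X and Z) recovers the interventional effect of X on Y from purely observational data, whereas the naive regression does not. Without the DAG, however, the models X → Y and Y → X fit the observational data equally well.

```python
import numpy as np


def simulate(n, seed=0):
    """Observational data from the toy SCM Z -> X, Z -> Y, X -> Y (an assumption)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n)
    x = z + rng.standard_normal(n)
    y = 2.0 * x + 3.0 * z + rng.standard_normal(n)
    return z, x, y


def ols(target, *features):
    """Least-squares coefficients of target on [1, features...]."""
    design = np.column_stack([np.ones_like(target), *features])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef


def gaussian_loglik(residuals):
    """Gaussian log-likelihood of zero-mean residuals at the MLE variance."""
    n, var = residuals.size, np.mean(residuals ** 2)
    return -0.5 * n * (np.log(2.0 * np.pi * var) + 1.0)


def direction_loglik(cause, effect):
    """Maximized log-likelihood of the linear-Gaussian DAG cause -> effect."""
    marginal = gaussian_loglik(cause - cause.mean())
    b0, b1 = ols(effect, cause)
    conditional = gaussian_loglik(effect - (b0 + b1 * cause))
    return marginal + conditional


z, x, y = simulate(100_000)
naive = ols(y, x)[1]        # about 3.5: the effect of X is confounded by Z
adjusted = ols(y, x, z)[1]  # about 2.0: backdoor adjustment using the known DAG
ll_x_to_y = direction_loglik(x, y)
ll_y_to_x = direction_loglik(y, x)  # equal up to rounding: the data cannot orient X - Y
```

Swapping `cause` and `effect` changes the model but not its fit, which is the two-variable version of the Markov equivalence problem. It is the causal assumption encoded in the DAG, not the data, that licenses the adjustment.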

This puts healthcare practitioners in a dilemma: while many pressing issues in healthcare could benefit from a causal perspective, existing causal methods require strong assumptions that many are uncomfortable making, as illustrated in the figure below.

Motivated by this observation, we introduce causal deep learning, our lab’s vision for causally inspired deep learning. Our goal is to develop both theory and learning algorithms that are more accurate, robust, generalizable, and fair. Importantly, we focus on properties that are empirically verifiable and crucial for healthcare practitioners. Our perspective contrasts sharply with most of the causality literature, which focuses on theoretical analysis based on strong, unverifiable assumptions. Causal deep learning would also advance the deep learning literature, where many existing attempts to improve robustness and generalizability are ad hoc and unprincipled.

A useful analogy for the relation between causality and causal deep learning is the relation between learning theory and deep learning. Like causality, learning theory establishes theoretical guarantees (here, on learnability) based on strong assumptions such as i.i.d. samples and Lipschitz continuity. Inspired by learning theory, deep learning shifts the focus to empirical performance on real datasets, where those assumptions may not hold. We believe that this shift from theoretical analysis to practical performance is a key driver of the success of deep learning over the last decade, and we intend to replicate that success in causal deep learning. Below, we share some of our lab’s developments in this area. While this work is practical in nature, a future focus is to provide theoretical insights into causal deep learning.

Causal deep learning

Consider the figure below, known as “the ladder of causation”, where each higher rung represents a more involved notion of inference. The lowest rung is association, where we place most of machine learning (and thus deep learning): for example, we associate an image of a cat with its label. The second rung is intervention, the first type of causal inference. With interventional inference, we aim to predict the outcome after altering the system in which we make predictions, thereby deviating from the training data. Lastly, counterfactual inference aims to predict, after the fact, what would have happened had things been different.
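The difference between the three rungs can be made concrete on a toy linear SCM (an assumption for illustration): Z = U_Z, X = Z + U_X, Y = 2X + 3Z + U_Y, with independent standard-normal noise. Rung 1 conditions on X, rung 2 replaces X's equation, and rung 3 follows the standard abduction, action and prediction steps for a single observed unit. The functions below are hypothetical helpers that hard-code this SCM.

```python
# Toy linear SCM (assumption): Z = U_Z, X = Z + U_X, Y = 2*X + 3*Z + U_Y,
# with U_Z, U_X, U_Y independent N(0, 1).


def rung1_association(x):
    """E[Y | X = x]. Given X = x, E[Z | X = x] = x / 2, so E[Y | X = x] = 2x + 1.5x."""
    return 3.5 * x


def rung2_intervention(x):
    """E[Y | do(X = x)]. X's equation is replaced, Z keeps E[Z] = 0, so E[Y] = 2x."""
    return 2.0 * x


def rung3_counterfactual(z_obs, x_obs, y_obs, x_new):
    """Y for one observed unit, had X been x_new instead of x_obs."""
    # 1. Abduction: infer this unit's exogenous noise from what was observed.
    u_z = z_obs
    u_y = y_obs - 2.0 * x_obs - 3.0 * z_obs
    # 2. Action: replace X's structural equation by the constant x_new.
    x = x_new
    # 3. Prediction: recompute with the same noise (Z is not a descendant of X).
    z = u_z
    return 2.0 * x + 3.0 * z + u_y


print(rung1_association(1.0), rung2_intervention(1.0))  # 3.5 2.0
print(rung3_counterfactual(z_obs=1.0, x_obs=2.0, y_obs=9.0, x_new=0.0))  # 5.0
```

For the observed unit (Z = 1, X = 2, Y = 9), abduction gives U_Y = 2, so the counterfactual outcome under X = 0 is 5, whereas the population-level answer E[Y | do(X = 0)] is 0: the counterfactual uses what was learned about this particular unit.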

This ladder of causation is a useful “North Star” analogy: the higher a method sits on the ladder, the more informative its predictions. By this measure, associative methods (such as most of deep learning) still have a long way to go. We believe causal deep learning moves methods up the ladder, to somewhere between rung 1 and rung 2, and we therefore propose to add a “rung 1.5”. While CDL does more than mere association (as we explain below), it does not perform interventional inference in the strict causal sense, as this requires many strong (and sometimes unrealistic) assumptions that we are unwilling to make. On rung 1.5 we find many interesting ML problems, such as domain adaptation, meta-learning, transfer learning, fairness, data augmentation, and distributional shift. In contrast to rung 2, a rung 1.5 model should be empirically verifiable and “good enough” for the task at hand.

In the remainder of this section we show how ideas from causality can make deep learning models more robust, general, and fair. The key observation is that causality provides a principled framework for deep learning methods to leverage expert domain knowledge. In many problems, although the true SCM is unknown, some partial knowledge about the causal structure is available (for instance, domain experts may specify a partial DAG). We would like to put such information to use in learning and reasoning by using causality as an inductive bias for deep learning methods. In doing so, causal deep learning methods can learn informative representations that extend beyond the scope of the training data, since causal knowledge is expected to remain constant across environments.
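One generic way to use a partial DAG as an inductive bias is to mask a layer's weights so that each output can only depend on the parents of the corresponding variable. The sketch below (forward pass only) illustrates this idea; the mask, the variable names and the `MaskedLayer` class are hypothetical, and this is not the code of any of the methods described below.

```python
import numpy as np

# Hypothetical partial DAG over four variables, e.g. elicited from domain experts:
# PARENT_MASK[i, j] = 1 means variable i is a known parent of variable j.
PARENT_MASK = np.array([
    [0, 1, 1, 0],  # x0 -> x1, x0 -> x2
    [0, 0, 0, 1],  # x1 -> x3
    [0, 0, 0, 1],  # x2 -> x3
    [0, 0, 0, 0],  # x3 has no children
], dtype=float)


class MaskedLayer:
    """Linear layer whose output j can only depend on the parents of variable j."""

    def __init__(self, mask, rng):
        self.mask = mask
        self.weight = rng.standard_normal(mask.shape)  # unconstrained parameters
        self.bias = np.zeros(mask.shape[1])

    def __call__(self, x):
        # Zeroing masked weights removes every non-parent from the hypothesis space.
        return x @ (self.weight * self.mask) + self.bias


rng = np.random.default_rng(0)
layer = MaskedLayer(PARENT_MASK, rng)
x = rng.standard_normal((5, 4))
x_shifted = x.copy()
x_shifted[:, 0] += 10.0  # change only x0
change = layer(x_shifted) - layer(x)  # only outputs for x1 and x2 (children of x0) move
```

Because masked weights are multiplied by zero, their gradients are also zero, so gradient-based training cannot undo the structural constraint; the mask is where expert knowledge enters the model.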

Below, we provide examples in four major areas of machine learning: supervised learning, missing data imputation, domain generalization, and fairness. The key point is that incorporating even partial causal information into a model can bring substantial benefits in performance, generalization, and the range of problems we can solve. Causal deep learning is of course not limited to these examples, but they go a long way in showing its versatility.

Supervised Learning

Even without leveraging the interventional abilities of graphical causal models, we can greatly improve standard tasks in deep learning. Constraining a model with structural equations significantly reduces its hypothesis space, and this holds for deep learning models as well. Regularization has a similar function; however, standard regularization schemes aim to simplify models (e.g., by shrinking coefficients towards zero) rather than to exploit structure. One example of exploiting structure instead is CAusal STructure LEarning (CASTLE), published at NeurIPS 2020.

With CASTLE, one defines a multi-objective learning target: on the one hand, we want a neural network to perform very well on a supervised main-task (for example, predicting housing prices); while on the other hand, we want the network to discover the graphical structure of the underlying data-generating process. The main idea is that the structure of the function learnt by the neural network is guided by the structure of the graphical model. With the DAG-learning objective incorporated in the network’s learning target, the hypothesis space of the network is reduced.
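A minimal sketch of such a multi-objective target, assuming a squared-error main task, a reconstruction of the inputs through a learned weighted adjacency matrix W, and a differentiable acyclicity penalty. We use the polynomial form h(W) = tr[(I + W∘W/d)^d] − d, a known variant of the NOTEARS acyclicity constraint that is zero exactly when W describes a DAG. The weights `lam`, `rho`, `beta` and all function names are hypothetical; this illustrates the idea and is not CASTLE's exact objective or architecture (in CASTLE the graph is embedded in the network's input layers, whereas here the reconstruction `x_hat` is passed in directly).

```python
import numpy as np


def acyclicity(w):
    """DAG penalty h(W) = tr[(I + W*W/d)^d] - d; zero iff W has no directed cycle."""
    d = w.shape[0]
    m = np.eye(d) + (w * w) / d
    return np.trace(np.linalg.matrix_power(m, d)) - d


def multi_objective_loss(y, y_hat, x, x_hat, w, lam=1.0, rho=1.0, beta=0.01):
    """Supervised loss plus a structure-learning regularizer (weights hypothetical).

    y, y_hat : targets and predictions of the supervised main task
    x, x_hat : input features and their reconstruction through the learned graph
    w        : weighted adjacency matrix of the learned graph
    """
    prediction = np.mean((y - y_hat) ** 2)
    reconstruction = np.mean((x - x_hat) ** 2)
    dag_penalty = rho * acyclicity(w)
    sparsity = beta * np.abs(w).sum()  # L1 term encourages sparse graphs
    return prediction + lam * (reconstruction + dag_penalty + sparsity)


dag = np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]])  # 0->1, 0->2, 1->2
cycle = np.array([[0.0, 1.0], [1.0, 0.0]])  # 0->1 and 1->0
print(acyclicity(dag), acyclicity(cycle))  # 0.0 0.5
```

Minimizing the penalty pushes the learned graph towards a DAG, which is what restricts the functions the network can represent.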

CASTLE: Regularization via Auxiliary Causal Graph Discovery

Trent Kyono, Yao Zhang, Mihaela van der Schaar

NeurIPS 2020

Regularization improves generalization of supervised models to out-of-sample data. Prior works have shown that prediction in the causal direction (effect from cause) results in lower testing error than the anti-causal direction. However, existing regularization methods are agnostic of causality. We introduce Causal Structure Learning (CASTLE) regularization and propose to regularize a neural network by jointly learning the causal relationships between variables. CASTLE learns the causal directed acyclical graph (DAG) as an adjacency matrix embedded in the neural network’s input layers, thereby facilitating the discovery of optimal predictors. Furthermore, CASTLE efficiently reconstructs only the features in the causal DAG that have a causal neighbor, whereas reconstruction-based regularizers suboptimally reconstruct all input features. We provide a theoretical generalization bound for our approach and conduct experiments on a plethora of synthetic and real publicly available datasets demonstrating that CASTLE consistently leads to better out-of-sample predictions as compared to other popular benchmark regularizers.

Imputation

While CASTLE addresses a problem that has received a lot of attention in the literature (supervised learning), let us now turn to one that has received considerably less: imputation. Imputation is nonetheless very important, as real-world data is rarely complete. Whenever a variable is missing from a sample, we can impute it, i.e. fill it in with an estimated replacement. Naturally, the more closely the imputed value resembles the missing one, the better.

Previous methods estimate the missing variable from the observed variables, for example using multiple imputation techniques. However, there is something counterintuitive about doing this: if the missing variable is estimated purely from the observed ones, the imputation cannot, mathematically, add any information to the sample that was not already there.

With causality, we take a different approach. As in CASTLE, we model the data-generating process as a causal graphical model. The motivation is that, when certain causal links are known, the DAG encodes additional information beyond the data itself. If, for example, we know that a variable is caused by one variable and in turn causes another, we can impute it so that it is consistent with both its parent and its child. An example of such an approach is MIRACLE, published at NeurIPS 2021. With MIRACLE, one can impute the data so that the completed data remains consistent with the underlying causal structure (whether provided or learned).
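A toy illustration of why respecting the structure matters, assuming a known chain X → M → Y with M missing completely at random (the data, coefficients and helper names are hypothetical). Mean imputation distorts the X → M mechanism in the completed data, while imputing M from its parent X and child Y keeps that mechanism intact. This simple regression imputation is a stand-in for the idea only; MIRACLE itself regularizes a baseline imputer while modelling the missingness mechanism.

```python
import numpy as np


def simulate_chain(n, seed=0):
    """Toy chain X -> M -> Y (an assumption), with M missing completely at random."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    m = 1.5 * x + 0.5 * rng.standard_normal(n)
    y = -2.0 * m + 0.5 * rng.standard_normal(n)
    missing = rng.random(n) < 0.3  # about 30% of M is missing
    return x, m, y, missing


def impute_mean(m, missing):
    """Baseline: replace every missing M by the mean of the observed M values."""
    filled = m.copy()
    filled[missing] = m[~missing].mean()
    return filled


def impute_from_parent_and_child(x, m, y, missing):
    """Regress M on its parent X and its child Y using complete rows, then predict."""
    design = np.column_stack([np.ones_like(x), x, y])
    coef, *_ = np.linalg.lstsq(design[~missing], m[~missing], rcond=None)
    filled = m.copy()
    filled[missing] = design[missing] @ coef
    return filled


def mechanism_slope(cause, effect):
    """Slope of the linear mechanism cause -> effect estimated on completed data."""
    return float(np.cov(cause, effect)[0, 1] / np.var(cause, ddof=1))


x, m, y, missing = simulate_chain(20_000)
slope_mean = mechanism_slope(x, impute_mean(m, missing))  # attenuated, well below 1.5
slope_graph = mechanism_slope(x, impute_from_parent_and_child(x, m, y, missing))  # near 1.5
```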

MIRACLE: Causally-Aware Imputation via Learning Missing Data Mechanisms

Trent Kyono*, Yao Zhang*, Alexis Bellot, Mihaela van der Schaar

NeurIPS 2021

Missing data is an important problem in machine learning practice. Starting from the premise that imputation methods should preserve the causal structure of the data, we develop a regularization scheme that encourages any baseline imputation method to be causally consistent with the underlying data generating mechanism. Our proposal is a causally-aware imputation algorithm (MIRACLE). MIRACLE iteratively refines the imputation of a baseline by simultaneously modeling the missingness generating mechanism, encouraging imputation to be consistent with the causal structure of the data. We conduct extensive experiments on synthetic and a variety of publicly available datasets to show that MIRACLE is able to consistently improve imputation over a variety of benchmark methods across all three missingness scenarios: at random, completely at random, and not at random.

Domain generalization

Consider multiple environments that all depict the same phenomenon, such as ICU departments in different hospitals. One could learn from each environment individually, but since they likely share the same underlying (latent) structure, it may be beneficial to learn from all of them at once. Doing so can reduce the risk of learning spurious relations, resulting in models that generalize better to new hospitals. Of course, latent “structure” is quite vague, which makes it an ideal candidate for causality to shed some light on.

One way to generalize from one domain to another is invariant causal imitation learning (ICIL), published at NeurIPS 2021, where one assumes two sets of variables: one set that is environment-invariant and one that is environment-specific. By simply assuming that both sets cause the observation at hand, while only the invariant set is shared across environments, we obtain a clear definition of how to compose a representation. In ICIL, one trains a classifier that aims to recognize which environment (i.e. hospital) a sample comes from. That classifier is then used to regularize the imitation learning objective so that the learned representation confuses it: if the environment classifier cannot predict which environment a sample comes from, the representation is invariant across environments.
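A minimal sketch of the representation-side regularizer, assuming the environment classifier outputs logits over K training environments. We use the KL divergence from the classifier's predictive distribution to the uniform distribution as the “confusion” penalty; this choice, the weight `lam` and all function names are illustrative assumptions, not ICIL's exact objective (which, per the abstract below, also includes an energy-based term). In adversarial training, the classifier is updated to predict the environment, while the representation is updated to minimize this objective.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)


def confusion_penalty(env_logits):
    """Mean KL(p || uniform) = log K - H(p) over a batch of environment predictions.

    Zero exactly when the environment classifier outputs the uniform distribution
    over the K training environments, i.e. when it cannot tell them apart.
    """
    p = softmax(env_logits)
    k = p.shape[1]
    entropy = -np.sum(p * np.log(np.clip(p, 1e-12, 1.0)), axis=1)
    return float(np.mean(np.log(k) - entropy))


def representation_objective(imitation_loss, env_logits, lam=0.1):
    """Imitation loss plus a penalty for representations that reveal the environment."""
    return imitation_loss + lam * confusion_penalty(env_logits)


uniform_logits = np.zeros((8, 3))                      # classifier fully confused
confident_logits = np.tile([10.0, 0.0, 0.0], (8, 1))   # classifier sure of the environment
print(confusion_penalty(uniform_logits), confusion_penalty(confident_logits))
```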

Invariant Causal Imitation Learning for Generalizable Policies

Ioana Bica, Daniel Jarrett, Mihaela van der Schaar

NeurIPS 2021

Consider learning an imitation policy on the basis of demonstrated behavior from multiple environments, with an eye towards deployment in an unseen environment. Since the observable features from each setting may be different, directly learning individual policies as mappings from features to actions is prone to spurious correlations—and may not generalize well. However, the expert’s policy is often a function of a shared latent structure underlying those observable features that is invariant across settings. By leveraging data from multiple environments, we propose Invariant Causal Imitation Learning (ICIL), a novel technique in which we learn a feature representation that is invariant across domains, on the basis of which we learn an imitation policy that matches expert behavior. To cope with transition dynamics mismatch, ICIL learns a shared representation of causal features (for all training environments), that is disentangled from the specific representations of noise variables (for each of those environments). Moreover, to ensure that the learned policy matches the observation distribution of the expert’s policy, ICIL estimates the energy of the expert’s observations and uses a regularization term that minimizes the imitator policy’s next state energy. Experimentally, we compare our methods against several benchmarks in control and healthcare tasks and show its effectiveness in learning imitation policies capable of generalizing to unseen environments.

Thanks to the assumed causal structure, ICIL performs favorably in new environments. We can evaluate this using pretrained agents in OpenAI Gym environments, which serve as the expert policies. Data is collected by running these expert policies in different domains of OpenAI Gym; by training on some domains and evaluating on a previously unseen domain, we can test ICIL’s out-of-distribution performance.

A natural question to ask then is: what if we knew more about the environment? In ICIL, one assumes only a very simple graphical structure: the variables need only be separable into invariant and environment-specific factors, and no further structure within these sets is assumed.

Let’s now consider the setting where we do know more about the invariant structure. Consider the causal assurance score (CAS), published in IEEE Transactions on AI, which exploits knowledge of a causal graph to evaluate how faithful a model is to the underlying structure. The key premise is that the true causal structure is invariant across datasets that capture the same variables. With CAS, we can perform model selection for predictive models under unsupervised domain adaptation.
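A heavily simplified sketch of the idea, assuming a known linear-Gaussian DAG X1 → Y → X2 (with Y the label) and mechanisms that are invariant across domains. Each candidate model's predictions on the unlabeled target domain are scored by their likelihood under the mechanisms fitted on the source domain, and the highest-scoring model is selected. This is not the published CAS implementation; the data, candidate models and function names are illustrative assumptions.

```python
import numpy as np


def fit_linear_gaussian(target, parent):
    """Fit target ~ N(c0 + c1 * parent, var) by least squares on source-domain data."""
    design = np.column_stack([np.ones_like(target), parent])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    var = np.mean((target - design @ coef) ** 2)
    return coef, var


def log_likelihood(target, parent, params):
    """Gaussian log-likelihood of target given parent under fitted params."""
    coef, var = params
    resid = target - (coef[0] + coef[1] * parent)
    return float(np.sum(-0.5 * np.log(2.0 * np.pi * var) - resid ** 2 / (2.0 * var)))


def causal_fit_score(x1, x2, y_pred, mech_y, mech_x2):
    """Score target-domain predictions against the assumed DAG X1 -> Y -> X2.

    Higher is better: predictions should fit p(Y | X1) and p(X2 | Y), which
    were fitted on the source domain and are assumed invariant across domains.
    """
    return log_likelihood(y_pred, x1, mech_y) + log_likelihood(x2, y_pred, mech_x2)


rng = np.random.default_rng(0)
n = 5_000
# Labeled source domain.
x1_s = rng.standard_normal(n)
y_s = 1.5 * x1_s + 0.5 * rng.standard_normal(n)
x2_s = 2.0 * y_s + 0.5 * rng.standard_normal(n)
mech_y, mech_x2 = fit_linear_gaussian(y_s, x1_s), fit_linear_gaussian(x2_s, y_s)

# Unlabeled target domain: the distribution of X1 shifts, mechanisms stay the same.
x1_t = rng.normal(2.0, 1.0, n)
y_t = 1.5 * x1_t + 0.5 * rng.standard_normal(n)  # never shown to the selector
x2_t = 2.0 * y_t + 0.5 * rng.standard_normal(n)

# Two hypothetical candidate models.
w, *_ = np.linalg.lstsq(np.column_stack([x1_s, x2_s]), y_s, rcond=None)
candidates = {
    "linear_regression": np.column_stack([x1_t, x2_t]) @ w,
    "constant_zero": np.zeros(n),
}
scores = {name: causal_fit_score(x1_t, x2_t, pred, mech_y, mech_x2)
          for name, pred in candidates.items()}
selected = max(scores, key=scores.get)
```

Unlike selectors based only on density ratios of the inputs, this score looks at each model's predictions in the target domain, which is the gap the abstract below points to.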

Exploiting Causal Structure for Robust Model Selection in Unsupervised Domain Adaptation

Trent Kyono, Mihaela van der Schaar

IEEE Transactions on AI

In many real-world settings, such as healthcare, machine learning models are trained and validated on one labeled domain and tested or deployed on another where feature distributions differ, i.e., there is covariate shift. When annotations are costly or prohibitive, an unsupervised domain adaptation (UDA) regime can be leveraged requiring only unlabeled samples in the target domain. Existing UDA methods are unable to factor in a model’s predictive loss based on predictions in the target domain and therefore suboptimally leverage density ratios of only the input covariates in each domain. In this work we propose a model selection method for leveraging model predictions on a target domain without labels by exploiting the domain invariance of causal structure. We assume or learn a causal graph from the source domain, and select models that produce predicted distributions in the target domain that have the highest likelihood of fitting our causal graph. We thoroughly analyze our method under oracle knowledge using synthetic data. We then show on several real-world datasets, including several COVID-19 examples, that our method is able to improve on the state-of-the-art UDA algorithms for model selection.

Model selection for domain adaptation is also an important problem in treatment effect estimation. However, CAS is not applicable there, as it ignores the problem of missing counterfactuals (more on counterfactuals and individualized treatment effects (ITE) can be found in our research pillar on treatment effects). An extension of CAS to the treatment effects setting is interventional causal model selection (ICMS). Interestingly, ICMS solves a problem in the potential outcomes framework of causality by leveraging the rules of do-calculus, which stem from the graphical approach to causality.
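For readers less familiar with do-calculus, the identity at its core is the truncated factorization (also known as the g-formula): for a causal DAG over variables V_1, …, V_n with parent sets pa_i, intervening on a treatment T deletes T’s own factor from the observational factorization and fixes T = t everywhere else. This is the standard textbook identity, not ICMS’s scoring rule; the notation below is ours, with v_T denoting the value that (v_1, …, v_n) assigns to T.

```latex
p\bigl(v_1, \dots, v_n \mid \mathrm{do}(T = t)\bigr) =
\begin{cases}
  \displaystyle\prod_{i \,:\, V_i \neq T} p\bigl(v_i \mid \mathrm{pa}_i\bigr) & \text{if } v_T = t,\\[4pt]
  0 & \text{otherwise.}
\end{cases}
```

Every factor on the right-hand side is an observational conditional, which is why a correctly specified DAG lets interventional quantities be computed from observational data.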

Selecting Treatment Effects Models for Domain Adaptation Using Causal Knowledge

Trent Kyono, Ioana Bica, Zhaozhi Qian, Mihaela van der Schaar

arXiv

Selecting causal inference models for estimating individualized treatment effects (ITE) from observational data presents a unique challenge since the counterfactual outcomes are never observed. The problem is challenged further in the unsupervised domain adaptation (UDA) setting where we only have access to labeled samples in the source domain, but desire selecting a model that achieves good performance on a target domain for which only unlabeled samples are available. Existing techniques for UDA model selection are designed for the predictive setting. These methods examine discriminative density ratios between the input covariates in the source and target domain and do not factor in the model’s predictions in the target domain. Because of this, two models with identical performance on the source domain would receive the same risk score by existing methods, but in reality, have significantly different performance in the test domain. We leverage the invariance of causal structures across domains to propose a novel model selection metric specifically designed for ITE methods under the UDA setting. In particular, we propose selecting models whose predictions of interventions’ effects satisfy known causal structures in the target domain. Experimentally, our method selects ITE models that are more robust to covariate shifts on several healthcare datasets, including estimating the effect of ventilation in COVID-19 patients from different geographic locations.

Fairness

The models above perform better by exploiting causal structure. An admirable feat, of course, but let us now turn to a problem that can itself be defined as a causal problem: fairness. There are many definitions of fairness (most of which are not actually defined in terms of causality), but we argue that fairness is best defined in causal terms. A major reason is causality’s ability to model direction. For example, consider two variables: ethnicity and incarceration. Many studies show a statistical association between them: people of color are much more likely to be incarcerated, and, conversely, incarcerated people are much more likely to be people of color. While the latter statement is correct, it is not really helpful, as it is only a product of the underlying problem. We must distill it into a directional statement: because someone is a person of color, they are more likely to be incarcerated. This is a causal statement, which makes fairness a causal problem.

Note that there is likely a decision maker between ethnicity and incarceration (i.e., the effect of ethnicity on incarceration is mediated by the decision maker). The bias then manifests as a direct edge from ethnicity to the decision maker, rather than from ethnicity to incarceration. In an unbiased system, there would be no arrow stemming from ethnicity at all.

If we train a model on incarceration data (perhaps to provide recommendations to a judge), it will be biased against people of color: based on ethnicity, the model will recommend whether the accused should be incarcerated, which is of course unfair, as ethnicity should play no role in this decision.

If we know that past decisions were made on the basis of ethnicity, we can take this directional relationship into account. When we learn a generative model based on a graphical structure, we can give it a graph in which ethnicity causes incarceration during training, and a different graph in which ethnicity does not cause incarceration during sampling. In short, we provide the model with an intervened graph during the generation phase. This is what DECAF (published at NeurIPS 2021) does. If we then train the incarceration recommender on this (fair) synthetic data, its recommendations will be ethnicity-agnostic.
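A minimal linear-Gaussian sketch of the “train with one graph, sample with another” idea, assuming a protected attribute A, a legitimate feature X and a decision score D with DAG A → X, X → D, A → D. Each node's structural equation is fitted on the (biased) data; at sampling time only the direct A → D edge is removed. DECAF itself uses a GAN with the DAG embedded in the generator, so the linear fits here are a deliberate simplification, and all names and data are hypothetical. Removing only the direct edge leaves the path A → X → D intact; which edges to remove depends on the fairness definition being targeted.

```python
import numpy as np


def fit_node(target, parents):
    """Fit the linear structural equation target = c0 + sum_k c_k * parent_k + noise."""
    design = np.column_stack([np.ones_like(target), *parents])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    noise_sd = np.std(target - design @ coef)
    return coef, noise_sd


def generate(n, rng, p_a, eq_x, eq_d, remove_a_to_d=False):
    """Sample synthetic (A, X, D) along the DAG A -> X, A -> D, X -> D.

    Setting remove_a_to_d=True zeroes the direct A -> D coefficient at
    sampling time only, i.e. the biased edge is removed after training.
    """
    (cx, sx), (cd, s_d) = eq_x, eq_d
    a = (rng.random(n) < p_a).astype(float)
    x = cx[0] + cx[1] * a + sx * rng.standard_normal(n)
    a_to_d = 0.0 if remove_a_to_d else cd[2]
    d = cd[0] + cd[1] * x + a_to_d * a + s_d * rng.standard_normal(n)
    return a, x, d


# Hypothetical biased "real" data: A protected, X legitimate feature, D decision score.
rng = np.random.default_rng(0)
n = 50_000
a = (rng.random(n) < 0.5).astype(float)
x = 0.5 * a + rng.standard_normal(n)
d = 1.0 * x + 1.0 * a + 0.5 * rng.standard_normal(n)

# "Training": fit every node's equation on the biased data (D coefficients: [1, x, a]).
eq_x, eq_d = fit_node(x, [a]), fit_node(d, [x, a])

# "Sampling": generate with the direct A -> D edge removed.
a_syn, x_syn, d_syn = generate(n, rng, a.mean(), eq_x, eq_d, remove_a_to_d=True)
direct_effect_real = eq_d[0][2]                            # about 1.0
direct_effect_syn = fit_node(d_syn, [x_syn, a_syn])[0][2]  # about 0.0
```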

DECAF: Generating Fair Synthetic Data Using Causally-Aware Generative Networks

Trent Kyono*, Boris van Breugel*, Jeroen Berrevoets, Mihaela van der Schaar

NeurIPS 2021

Machine learning models have been criticized for reflecting unfair biases in the training data. Instead of solving for this by introducing fair learning algorithms directly, we focus on generating fair synthetic data, such that any downstream learner is fair. Generating fair synthetic data from unfair data – while remaining truthful to the underlying data-generating process (DGP) – is non-trivial. In this paper, we introduce DECAF: a GAN-based fair synthetic data generator for tabular data. With DECAF we embed the DGP explicitly as a structural causal model in the input layers of the generator, allowing each variable to be reconstructed conditioned on its causal parents. This procedure enables inference time debiasing, where biased edges can be strategically removed for satisfying user-defined fairness requirements. The DECAF framework is versatile and compatible with several popular definitions of fairness. In our experiments, we show that DECAF successfully removes undesired bias and – in contrast to existing methods – is capable of generating high-quality synthetic data. Furthermore, we provide theoretical guarantees on the generator’s convergence and the fairness of downstream models.

Evaluating DECAF requires checking both that the synthetic data remains faithful, in the sense of consisting of realistic samples (as is typical when generating synthetic data), and that the samples no longer contain unfair associations. For example, we can compare a downstream classifier’s predictions when only the protected variable is changed: if the difference is large, the data still contains these unfair associations.
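One simple way to run such a check, as a sketch: take a downstream model (in practice, one trained on the synthetic data), flip only the protected attribute, and measure how much its predictions move. The function name, the 0/1 encoding of the protected column and the two stand-in scorers are assumptions for illustration.

```python
import numpy as np


def flip_test(predict, features, protected_col):
    """Mean absolute change in predictions when only a 0/1 protected column is flipped."""
    flipped = features.copy()
    flipped[:, protected_col] = 1.0 - flipped[:, protected_col]
    return float(np.mean(np.abs(predict(flipped) - predict(features))))


# Two hypothetical downstream scorers over features [protected, other].
def scorer_ignoring_protected(f):
    return 2.0 * f[:, 1]


def scorer_using_protected(f):
    return 2.0 * f[:, 1] + 1.0 * f[:, 0]


rng = np.random.default_rng(0)
features = np.column_stack([rng.random(100) < 0.5, rng.standard_normal(100)]).astype(float)
print(flip_test(scorer_ignoring_protected, features, 0))  # 0.0
print(flip_test(scorer_using_protected, features, 0))     # about 1.0
```

A value near zero means the model's outputs do not change when only the protected attribute changes; it does not rule out dependence through other features, which is why the choice of removed edges matters.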

Conclusion

In summary, causal deep learning comes down to training deep models with more than just data. Deep models are typically trained on huge amounts of data, yet they may still fail to capture some crucial information. This is especially awkward when that information is quite obvious to humans. Causal deep learning lets humans encode such information into models through causal structures.

We are excited for the future of causal deep learning. If you are too, we invite you to our upcoming Inspiration Exchange session, where we’ll discuss many of these ideas. If you’d like to join us, please sign up here: https://www.vanderschaar-lab.com/engagement-sessions/inspiration-exchange/