MetaLearners for CATE estimation

We recently published metalearners, a library for estimating Conditional Average Treatment Effects with MetaLearners in Python.

What are MetaLearners?

MetaLearners, outlined e.g. by Kuenzel et al. (2019), are a popular approach for estimating Conditional Average Treatment Effects (CATEs) in Causal Inference. The CATE quantifies the causal effect of an intervention (treatment) on an outcome of interest, given subject-specific features or covariates.

Let's look into an example for the sake of concreteness:

  • We have the choice of making a person $i$ listen to Bach or not. We consider this assignment of listening to Bach or not to be our treatment and say that $w_i = 0$ if person $i$ is not made to listen to Bach and $w_i = 1$ if the person is made to listen to Bach.
  • We care about the outcome (response) $y_i$ of a person, representing their joyfulness after listening or not listening to Bach, as a continuous scalar variable.
  • We have a corpus of data with triplets $(x_i, w_i, y_i)$ where $x_i$ captures some features (covariates) of the person, such as their age or whether they are a musician or not. A small, made-up example of such a corpus is sketched right after this list.
  • We would like to learn from this corpus which persons enjoy listening to Bach and, therefore, to which persons we should play Bach, based on their covariates. We can formalize this with the notion of CATEs: if the CATE of playing Bach, given age and whether someone is a musician, is positive, we'll want to play it. In other words, we can define a policy based on CATE estimates.
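
To make the data layout concrete, here is a tiny, entirely made-up corpus of such triplets in a pandas DataFrame; the column names are our own choice and are reused in the code examples further below.

import pandas as pd

# A tiny, made-up corpus of (x_i, w_i, y_i) triplets, for illustration only.
df = pd.DataFrame(
    {
        "age": [23, 35, 47, 29, 61],
        "is_musician": [True, False, True, False, False],
        "treatment": [1, 0, 1, 1, 0],  # w_i: 1 if made to listen to Bach, else 0
        "joyfulness": [7.2, 5.1, 8.4, 4.9, 5.5],  # y_i: continuous outcome
    }
)

feature_columns = ["age", "is_musician"]
treatment_column = "treatment"
outcome_column = "joyfulness"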

Estimating CATEs is not a regular prediction problem since it relies on interventions and seeks to make causal statements.

Moreover, CATE learning is particularly challenging due to the fundamental problem of Causal Inference, i.e. that we observe for a person $i$ either the outcome under treatment or the outcome without treatment, but never both. For more details, please see the background section of the metalearners documentation.

MetaLearners are one family of approaches for estimating CATEs. In particular, they decompose the overall estimation problem into several regular prediction, i.e. classification or regression, problems. These prediction problems can then be tackled with arbitrary models of choice, for instance decision tree ensembles from lightgbm, generalized linear models from glum or anything from scikit-learn.
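
To convey the idea of this decomposition, here is a minimal, hand-rolled sketch of the simplest such scheme, in the spirit of a T-learner: fit one regular outcome model per treatment variant and subtract their predictions. This is merely an illustration under simplifying assumptions, not the library's implementation.

from lightgbm import LGBMRegressor

def naive_t_learner_cate(X, y, w):
    # One regular regression problem per treatment variant ...
    mu_0 = LGBMRegressor().fit(X[w == 0], y[w == 0])  # outcome model for control
    mu_1 = LGBMRegressor().fit(X[w == 1], y[w == 1])  # outcome model for treatment
    # ... and the CATE estimate is the difference of their predictions.
    return mu_1.predict(X) - mu_0.predict(X)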

We can see this modularity in the following illustration

[Illustration: schematic of a MetaLearner and its component prediction models]

where the shaded area represents the MetaLearner and the white triangles individual classifiers or regressors employed by the MetaLearner.

Furthermore, we can gauge that the input to the MetaLearner is data $\mathcal{D} = (W, X, Y)$. Akin to prediction estimators, it consumes $X$, feature/covariate values per unit, as well as $Y$, observed outcomes per unit. Unlike prediction estimators, it also expects $W$, an observed treatment variant assignment per unit. This data can stem from an experiment, such as a Randomized Controlled Trial, or, under certain technical assumptions, from observational data. For more details on this, see this section of the metalearners documentation.

The MetaLearner's output can either be thought of as a function of CATEs $\hat{\tau}(\cdot)$ or as point CATE estimates $\hat{\tau}(X)$ for given covariates.

Importantly, the MetaLearner can then either be used for 'in-sample' data, i.e. data it has seen during training, or 'out-of-sample' data, i.e. data it hasn't seen during training. The latter case resembles a typical Machine Learning inference setting.

Why estimate CATEs?

While there are various interesting use cases for CATE estimation, the one we'd like to emphasize is learning a policy.

Let's assume we trained a MetaLearner. In other words, we now have access to an approximate CATE function $\hat{\tau}$ which we can apply to unseen covariates/features.

Using the estimated CATE model we can assign treatment according to the optimal policy:

$$\hat{\pi}(x_i) = \begin{cases} 1 & \text{if } \hat{\tau}(x_i) > 0 \\ 0 & \text{if } \hat{\tau}(x_i) \leq 0 \end{cases}$$

where we assume (without loss of generality) that a more positive outcome is preferred. Thanks to the underlying CATE MetaLearner $\hat{\tau}$, this optimal policy can easily be applied to new, unseen data.
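
In code, this policy boils down to thresholding the CATE estimates at zero; the numbers below are made up for illustration.

import numpy as np

# CATE estimates tau_hat(x_i) per unit, e.g. obtained from a fitted MetaLearner.
cate_estimates = np.array([0.3, -0.1, 0.0, 1.2])

# Treat exactly the units with a strictly positive CATE estimate.
policy = (cate_estimates > 0).astype(int)  # -> array([1, 0, 0, 1])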

Tying this back to our example from before, we learned a function $\hat{\tau}$ based on experiment data which tells us what the expected change in joyfulness $y_i$ is if we play Bach to a person with specific covariates $x_i$. Now we can use this estimate for a simple policy rule: if the estimate says that playing Bach will have a positive effect, we'll play Bach. If not, we won't.

The definition of such a policy can further be adapted to scenarios where different treatment variants have different costs or where there is a fixed budget per treatment variant: we can always just pick the units with the highest or lowest CATE estimates.
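
As a sketch of the budget-constrained case, with made-up numbers and a hypothetical budget of two treatments, one could simply rank units by their CATE estimates:

import numpy as np

cate_estimates = np.array([0.3, -0.1, 0.0, 1.2, 0.7])
budget = 2  # we can afford to treat only two units

# Treat the `budget` units with the largest estimated CATE.
treated_units = np.argsort(cate_estimates)[-budget:]
policy = np.zeros(len(cate_estimates), dtype=int)
policy[treated_units] = 1  # -> array([0, 0, 0, 1, 1])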

Why did we build metalearners?

From a performance point of view, MetaLearners not only have desirable theoretical properties, see e.g. Kennedy (2020) or Nie et al. (2019), but also fare well empirically, see e.g. Caron et al. (2022).

From a usage point of view, MetaLearners are desirable since they decompose the problem of CATE estimation into several regular prediction, i.e. classification or regression, problems. This is not only conceptually elegant but also of great practical use: one can reuse mature and battle-tested prediction software, e.g. lightgbm or glum, for these decomposed sub-problems. Since we think very much from an end-to-end and process-integration perspective, the maturity and reliability of our modelling approach is of utmost importance to us.

Neither did we invent the method of MetaLearners, nor do we provide the first implementation thereof. Concretely, causalml and econml are pre-existing open-source Python implementations of some MetaLearners.

Having relied on both for a while, we came to realize that our use cases have requirements which aren't met by pre-existing implementations.

On a high level, we want to be able to

  • use the native treatment of categorical features in tree models, see e.g. LightGBM's categorical feature support, via pandas' category dtype (a small illustration follows after this list)
  • access, evaluate and extract component models of a MetaLearner
  • insert and reuse previously trained component models in a MetaLearner
  • parallelize the estimation of a MetaLearner's component models
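
To illustrate the first point, LightGBM can consume pandas' category dtype directly, without manual encoding; the data and column names below are made up for this example.

import pandas as pd
from lightgbm import LGBMRegressor

toy_df = pd.DataFrame(
    {
        "age": [23, 35, 47, 29, 61, 33],
        # A categorical feature expressed via pandas' category dtype.
        "favorite_composer": pd.Series(
            ["Bach", "Brahms", "Bach", "Mahler", "Brahms", "Bach"], dtype="category"
        ),
        "joyfulness": [7.2, 5.1, 8.4, 4.9, 5.5, 6.8],
    }
)

# LightGBM picks up the categorical column natively; no one-hot encoding needed.
model = LGBMRegressor().fit(toy_df[["age", "favorite_composer"]], toy_df["joyfulness"])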

You can find the detailed description of new requirements - which have been met in metalearners - in our documentation.

Moreover, we gave a talk on our previously unmet requirements at PyData Amsterdam.

How can one use metalearners?

The metalearners library can be found on PyPI and conda-forge. Hence, it can be installed either via pip

$ pip install metalearners

or via conda

$ conda install metalearners -c conda-forge

Once installed, the interface is very similar to the protocol one is used to from scikit-learn estimators.

If the data described above is found in a DataFrame df, CATE estimates can be produced as simply as this:

from metalearners import TLearner
from lightgbm import LGBMRegressor

tlearner = TLearner(
    # Choose any regressor that you deem fit.
    nuisance_model_factory=LGBMRegressor,
    # The outcome Y is a continuous scalar.
    is_classification=False,
    # Playing Bach or not playing Bach.
    n_variants=2,
)

tlearner.fit(
    X=df[feature_columns],
    y=df[outcome_column],
    w=df[treatment_column],
)

cate_estimates = tlearner.predict(
    X=df[feature_columns],
    is_oos=False,
)
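
Estimates for out-of-sample data, i.e. for covariates the MetaLearner has not seen during training, work analogously; the new_df below is made up for illustration:

import pandas as pd

# Previously unseen units with the same feature columns as during training.
new_df = pd.DataFrame({"age": [33, 52], "is_musician": [True, False]})

cate_estimates_oos = tlearner.predict(
    X=new_df[feature_columns],
    is_oos=True,
)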

You can find plenty more in-depth, self-contained and re-runnable examples with explanations in the examples section of our documentation.