carlg/alternative-metrics #950

Open
wants to merge 29 commits into base: main

Changes from all commits
29 commits
24c8a2f
Use two static methods to wrap alternative scoring formulas
carl-offerfit Jan 27, 2025
4d15237
Arg switch
carl-offerfit Jan 27, 2025
c6337cc
Do a better job of preserving the original
carl-offerfit Jan 27, 2025
3d4686d
Pass the scoring parameter through the hierarchy
carl-offerfit Jan 28, 2025
d9239a1
Add a public function to score the nuisance models on arbitrary data …
carl-offerfit Jan 28, 2025
cd007f7
Label nuisance score outputs
carl-offerfit Jan 28, 2025
f2958b7
Check if the nuisance scorer is a standard name
carl-offerfit Jan 28, 2025
8f38d8a
Better doc string comments
carl-offerfit Jan 28, 2025
0b8f497
Add the option to score the first stage treatment model by dimension
carl-offerfit Jan 29, 2025
2792545
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2025
84b8dac
Pre-commit fixes
carl-offerfit Jan 29, 2025
cb10a24
Merge branch 'carl/metrics' of github.com:carl-offerfit/EconML into c…
carl-offerfit Jan 29, 2025
699dd64
Merge branch 'main' into carl/metrics
carl-offerfit Jan 29, 2025
4c3077c
Undo the ruff changes made by mistake
carl-offerfit Jan 29, 2025
ae8393d
Merge branch 'carl/metrics' of github.com:carl-offerfit/EconML into c…
carl-offerfit Jan 29, 2025
d33a6a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2025
d092529
Just proper pre-commit fixes on _relearn, no reformatting
carl-offerfit Jan 29, 2025
a34355a
DML fixes, no reformatting.
carl-offerfit Jan 29, 2025
c6df717
Merge branch 'carl/metrics' of github.com:carl-offerfit/EconML into c…
carl-offerfit Jan 29, 2025
c717b78
Gets some basic tests working
carl-offerfit Feb 18, 2025
5f02270
Add binary metric functions
carl-offerfit Feb 18, 2025
d5f0fe4
Fix formatting
carl-offerfit Feb 18, 2025
148d3f9
More format fixes
carl-offerfit Feb 18, 2025
35f104f
Handle one hot encoding within score_nuisances
carl-offerfit Feb 18, 2025
d10a70d
Adjust test for dummies created in score_nuisances
carl-offerfit Feb 18, 2025
c411631
Reset the name - when not debugging, it does not go first
carl-offerfit Feb 18, 2025
3abd752
Move score nuisance from _ortho_learner to _rlearner - it really depe…
carl-offerfit Feb 20, 2025
39cc39d
Add logic so that _ortho_learner raises an exception if non-standard …
carl-offerfit Feb 20, 2025
540da71
Add default values in _ModelNuisance.score
carl-offerfit Feb 20, 2025
24 changes: 21 additions & 3 deletions econml/_ortho_learner.py
@@ -1027,7 +1027,7 @@ def effect_inference(self, X=None, *, T0=0, T1=1):

effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__

def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, scoring=None):
"""
Score the fitted CATE model on a new data set.

@@ -1055,6 +1055,9 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
Weights for each sample
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
scoring: str, optional
Name of an sklearn scoring function to use instead of the default.
Supported values are 'f1', 'log_loss', 'mean_absolute_error',
'mean_squared_error', 'r2', 'roc_auc', and 'pearsonr'.

Returns
-------
@@ -1113,9 +1116,24 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):

accumulated_nuisances += nuisances

score_kwargs = {
'X': X,
'W': W,
'Z': Z,
'sample_weight': sample_weight,
'groups': groups
}
# The scoring parameter can only be forwarded when the final model is an _rlearner._ModelFinal
if scoring is not None:
# Imported here rather than at module level to avoid circular imports
from .dml._rlearner import _ModelFinal
if isinstance(self._ortho_learner_model_final, _ModelFinal):
score_kwargs['scoring'] = scoring
else:
raise NotImplementedError("scoring parameter only implemented for "
"_rlearner._ModelFinal")
return self._ortho_learner_model_final.score(Y, T, nuisances=accumulated_nuisances,
**filter_none_kwargs(X=X, W=W, Z=Z,
sample_weight=sample_weight, groups=groups))
**filter_none_kwargs(**score_kwargs))

@property
def ortho_learner_model_final_(self):
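As a sketch of how the new scoring argument would be used end to end: the name passed to score is forwarded to _rlearner._ModelFinal and evaluated on the final-stage residuals. The snippet below is illustrative only, assuming this branch is installed; the estimator choice (LinearDML) and the synthetic data are not part of the PR.

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from econml.dml import LinearDML

# Illustrative synthetic data
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
W = rng.normal(size=(500, 2))
T = rng.normal(size=500)
y = 2.0 * T + X[:, 0] + rng.normal(size=500)

est = LinearDML(model_y=GradientBoostingRegressor(),
                model_t=GradientBoostingRegressor(), cv=2)
est.fit(y, T, X=X, W=W)

mse = est.score(y, T, X=X, W=W)                                 # default: residual-on-residual MSE
mae = est.score(y, T, X=X, W=W, scoring='mean_absolute_error')  # routed through _ModelFinal._wrap_scoring
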
132 changes: 122 additions & 10 deletions econml/dml/_rlearner.py
@@ -27,7 +27,16 @@

from abc import abstractmethod
import numpy as np

import pandas as pd
from sklearn.metrics import (
f1_score,
log_loss,
mean_absolute_error,
mean_squared_error,
r2_score,
roc_auc_score
)
from scipy.stats import pearsonr
from ..sklearn_extensions.model_selection import ModelSelector
from ..utilities import (filter_none_kwargs)
from .._ortho_learner import _OrthoLearner
@@ -54,10 +63,13 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
return self

def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None,
y_scoring='mean_squared_error', t_scoring='mean_squared_error', t_score_by_dim=False):
# note that groups are not passed to score because they are only used for fitting
T_score = self._model_t.score(X, W, T, **filter_none_kwargs(sample_weight=sample_weight))
Y_score = self._model_y.score(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight))
T_score = self._model_t.score(X, W, T, **filter_none_kwargs(sample_weight=sample_weight),
scoring=t_scoring, score_by_dim=t_score_by_dim)
Y_score = self._model_y.score(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight),
scoring=y_scoring)
return Y_score, T_score

def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
@@ -98,18 +110,60 @@ def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
def predict(self, X=None):
return self._model_final.predict(X)

def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, groups=None):
def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, groups=None,
scoring='mean_squared_error'):
Y_res, T_res = nuisances
if Y_res.ndim == 1:
Y_res = Y_res.reshape((-1, 1))
if T_res.ndim == 1:
T_res = T_res.reshape((-1, 1))
effects = self._model_final.predict(X).reshape((-1, Y_res.shape[1], T_res.shape[1]))
Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res).reshape(Y_res.shape)
if sample_weight is not None:
return np.mean(np.average((Y_res - Y_res_pred) ** 2, weights=sample_weight, axis=0))
return _ModelFinal._wrap_scoring(Y_true=Y_res, Y_pred=Y_res_pred, scoring=scoring, sample_weight=sample_weight)

@staticmethod
def _wrap_scoring(scoring, Y_true, Y_pred, sample_weight=None):
"""
Dispatch to one of several sklearn metric functions that accept sample weights.

Unfortunately there is no utility like get_scorer that is both generic and
supports sample weights.
"""
if scoring == 'f1':
return f1_score(Y_true, Y_pred, sample_weight=sample_weight)
elif scoring == 'mean_absolute_error':
return mean_absolute_error(Y_true, Y_pred, sample_weight=sample_weight)
elif scoring == 'mean_squared_error':
return mean_squared_error(Y_true, Y_pred, sample_weight=sample_weight)
elif scoring == 'r2':
return r2_score(Y_true, Y_pred, sample_weight=sample_weight)
elif scoring == 'roc_auc':
return roc_auc_score(Y_true, Y_pred, sample_weight=sample_weight)
elif scoring == 'log_loss':
return log_loss(Y_true, Y_pred, sample_weight=sample_weight)
elif scoring == 'pearsonr':
if sample_weight is not None:
raise NotImplementedError("pearsonr score does not support sample weighting")
return pearsonr(Y_true, Y_pred)
else:
return np.mean((Y_res - Y_res_pred) ** 2)
raise NotImplementedError(f"_wrap_scoring does not support '{scoring}'")

@staticmethod
def wrap_scoring(scoring, Y_true, Y_pred, sample_weight=None, score_by_dim=False):
"""
Score predictions, optionally returning one score per output dimension.

When score_by_dim is True, loop over the single-score wrapper for each column
(e.g. each treatment dimension of a multi-treatment model).
"""
if not score_by_dim:
return _ModelFinal._wrap_scoring(scoring, Y_true, Y_pred, sample_weight)
else:
assert Y_true.shape == Y_pred.shape, "Shape mismatch in wrap_scoring"
n_out = Y_pred.shape[1]
res = [None] * n_out
for yidx in range(n_out):
res[yidx] = _ModelFinal.wrap_scoring(scoring, Y_true[:, yidx], Y_pred[:, yidx], sample_weight)
return res
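A minimal sketch of calling the new static helper directly (assuming this branch is installed; _ModelFinal is a private class, so the import below mirrors the file path in this diff and is for illustration only):

import numpy as np
from econml.dml._rlearner import _ModelFinal

Y_true = np.random.normal(size=(200, 2))
Y_pred = Y_true + 0.1 * np.random.normal(size=(200, 2))

# One aggregate score across all columns
overall = _ModelFinal.wrap_scoring('mean_squared_error', Y_true, Y_pred)
# One score per output dimension (e.g. per treatment dimension)
per_dim = _ModelFinal.wrap_scoring('mean_squared_error', Y_true, Y_pred, score_by_dim=True)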


class _RLearner(_OrthoLearner):
@@ -422,7 +476,7 @@ def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sam
cache_values=cache_values,
inference=inference)

def score(self, Y, T, X=None, W=None, sample_weight=None):
def score(self, Y, T, X=None, W=None, sample_weight=None, scoring=None):
"""
Score the fitted CATE model on a new data set.

@@ -453,7 +507,7 @@ def score(self, Y, T, X=None, W=None, sample_weight=None):
The MSE of the final CATE model on the new data.
"""
# Replacing score from _OrthoLearner, to enforce Z=None and improve the docstring
return super().score(Y, T, X=X, W=W, sample_weight=sample_weight)
return super().score(Y, T, X=X, W=W, sample_weight=sample_weight, scoring=scoring)

@property
def rlearner_model_final_(self):
@@ -493,3 +547,61 @@ def residuals_(self):
"Set to `True` to enable residual storage.")
Y_res, T_res = self._cached_values.nuisances
return Y_res, T_res, self._cached_values.X, self._cached_values.W


def score_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, y_scoring=None,
t_scoring=None, t_score_by_dim=False):
"""
Score the fitted nuisance models on arbitrary data using any supported sklearn scoring.

The supported scorings depend on whether sample weights are used: if no sample
weights are given, any name returned by sklearn.metrics.get_scorer_names() is
available, as well as non-negated 'mean_squared_error' and 'mean_absolute_error'.
If sample weights are given, the supported scorings are 'f1', 'log_loss',
'mean_absolute_error', 'mean_squared_error', 'r2', and 'roc_auc'.

Parameters
----------
Y: (n, d_y) matrix or vector of length n
Outcomes for each sample
T: (n, d_t) matrix or vector of length n
Treatments for each sample
X: (n, d_x) matrix, optional
Features for each sample
W: (n, d_w) matrix, optional
Controls for each sample
Z: (n, d_z) matrix, optional
Instruments for each sample
sample_weight: (n,) vector, optional
Weights for each sample
t_scoring: str, optional
Name of an sklearn scoring function to use instead of the default for model_t
y_scoring: str, optional
Name of an sklearn scoring function to use instead of the default for model_y
t_score_by_dim: bool, default=False
Whether to score the prediction of each treatment dimension separately

Returns
-------
score_dict : dict[str,list[float]]
A dictionary whose keys identify the Y and T scorings used and whose values are
lists of scores, one per CV-fold model.
"""
Y_key = 'Y_defscore' if not y_scoring else f"Y_{y_scoring}"
T_key = 'T_defscore' if not t_scoring else f"T_{t_scoring}"
score_dict = {
Y_key: [],
T_key: []
}

# Discrete outcomes and treatments need to be one-hot encoded before scoring
Y_2_score = pd.get_dummies(Y) if self.discrete_outcome and (len(Y.shape) == 1 or Y.shape[1] == 1) else Y
T_2_score = pd.get_dummies(T) if self.discrete_treatment and (len(T.shape) == 1 or T.shape[1] == 1) else T

for m in self._models_nuisance[0]:
Y_score, T_score = m.score(Y_2_score, T_2_score, X=X, W=W, Z=Z, sample_weight=sample_weight,
y_scoring=y_scoring, t_scoring=t_scoring,
t_score_by_dim=t_score_by_dim)
score_dict[Y_key].append(Y_score)
score_dict[T_key].append(T_score)
return score_dict
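Continuing the illustrative LinearDML sketch above, score_nuisances can be called on training or held-out data; the keys follow the Y_<scoring> / T_<scoring> pattern built above. This is a usage sketch under the same assumptions, not part of the PR:

# One entry per CV-fold model of model_y / model_t
scores = est.score_nuisances(y, T, X=X, W=W, y_scoring='r2', t_scoring='r2')
print(scores['Y_r2'])   # list of r2 values, one per fold model of model_y
print(scores['T_r2'])   # list of r2 values, one per fold model of model_t

# Separate score for each treatment dimension (only meaningful for multi-dimensional T)
by_dim = est.score_nuisances(y, T, X=X, W=W,
                             t_scoring='mean_squared_error',
                             t_score_by_dim=True)
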
42 changes: 31 additions & 11 deletions econml/dml/dml.py
@@ -9,9 +9,11 @@
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (FunctionTransformer)
from sklearn.utils import check_random_state
from sklearn.metrics import get_scorer, get_scorer_names


from .._ortho_learner import _OrthoLearner
from ._rlearner import _RLearner
from ._rlearner import _RLearner, _ModelFinal
from .._cate_estimator import (DebiasedLassoCateEstimatorMixin,
LinearModelFinalCateEstimatorMixin,
StatsModelsCateEstimatorMixin,
@@ -52,20 +54,38 @@ def predict(self, X, W):
raise AttributeError("Cannot use a classifier as a first stage model when the target is continuous!")
return self._model.predict(_combine(X, W, n_samples))

def score(self, X, W, Target, sample_weight=None):
if hasattr(self._model, 'score'):
if self._discrete_target:
# In this case, the Target is the one-hot-encoding of the treatment variable
# We need to go back to the label representation of the one-hot so as to call
# the classifier.
Target = inverse_onehot(Target)
def score(self, X, W, Target, sample_weight=None, scoring=None, score_by_dim=False):
XW_combined = _combine(X, W, Target.shape[0])
if self._discrete_target:
# In this case, the Target is the one-hot-encoding of the treatment variable
# We need to go back to the label representation of the one-hot so as to call
# the classifier.
Target = inverse_onehot(Target)
if hasattr(self._model, 'score') and scoring is None and not score_by_dim:
# Standard default model scoring
if sample_weight is not None:
return self._model.score(_combine(X, W, Target.shape[0]), Target, sample_weight=sample_weight)
return self._model.score(XW_combined, Target, sample_weight=sample_weight)
else:
return self._model.score(_combine(X, W, Target.shape[0]), Target)
return self._model.score(XW_combined, Target)
else:
return None
return _FirstStageWrapper._wrap_scoring(scoring, Y_true=Target, X=XW_combined, est=self._model,
sample_weight=sample_weight, score_by_dim=score_by_dim)

@staticmethod
def _wrap_scoring(scoring, Y_true, X, est, sample_weight=None, score_by_dim=False):
"""
Dispatch between sklearn's get_scorer and _ModelFinal.wrap_scoring.

If there are no sample weights, use get_scorer to support any named sklearn
evaluation metric. Otherwise, use the static method from _ModelFinal that supports
weights; that version takes the predictions rather than the estimator.
"""
if sample_weight is None and not score_by_dim and scoring in get_scorer_names():
scorer = get_scorer(scoring)
return scorer(est, X, Y_true)
else:
Y_pred = est.predict(X)
return _ModelFinal.wrap_scoring(scoring, Y_true, Y_pred, sample_weight, score_by_dim=score_by_dim)

class _FirstStageSelector(SingleModelSelector):
def __init__(self, model: SingleModelSelector, discrete_target):
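The dispatch implemented in _FirstStageWrapper._wrap_scoring can be pictured with plain sklearn pieces. The sketch below is illustrative (model and data are made up, not part of the PR): the unweighted path goes through get_scorer, while the weighted path calls a metric function on predictions.

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import get_scorer, get_scorer_names, mean_squared_error

Xs = np.random.normal(size=(100, 3))
ys = Xs @ np.array([1.0, -2.0, 0.5]) + np.random.normal(size=100)
model = LinearRegression().fit(Xs, ys)

# Unweighted path: any named sklearn scorer works (error scorers are negated)
assert 'neg_mean_squared_error' in get_scorer_names()
neg_mse = get_scorer('neg_mean_squared_error')(model, Xs, ys)

# Weighted path: fall back to metric functions that take sample_weight,
# operating on predictions rather than the estimator
w = np.ones(100)
mse = mean_squared_error(ys, model.predict(Xs), sample_weight=w)
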
75 changes: 75 additions & 0 deletions econml/tests/test_dml.py
@@ -713,6 +713,81 @@ def true_fn(x):
np.testing.assert_array_less(lb - .01, truth)
np.testing.assert_array_less(truth, ub + .01)

def test_forest_dml_score_fns(self):
np.random.seed(1234)
n = 20000 # number of raw samples
d = 10

Z = np.random.binomial(1, .5, size=(n, d))
T = np.random.binomial(1, .5, size=(n,))

def true_fn(x):
return -1 + 2 * x[:, 0] + x[:, 1] * x[:, 2]

y = true_fn(Z) * T + Z[:, 0] + (1 * Z[:, 0] + 1) * np.random.normal(0, 1, size=(n,))
X = Z[:, :4]
W = Z[:, 4:]

est = CausalForestDML(model_y=GradientBoostingRegressor(n_estimators=30, min_samples_leaf=30),
model_t=GradientBoostingClassifier(n_estimators=30, min_samples_leaf=30),
discrete_treatment=True,
cv=2,
n_jobs=None,
n_estimators=1000,
max_samples=.4,
min_samples_leaf=10,
min_impurity_decrease=0.001,
verbose=0, min_var_fraction_leaf=.1,
fit_intercept=False,
random_state=12345)

est.fit(y, T, X=X, W=W)

s1 = est.score(Y=y, T=T, X=X, W=W, scoring='mean_squared_error')
s2 = est.score(Y=y, T=T, X=X, W=W)
assert s1 == s2
np.testing.assert_allclose(s1, 2.50, rtol=0, atol=.01)
s3 = est.score(Y=y, T=T, X=X, W=W, scoring='mean_absolute_error')
np.testing.assert_allclose(s3, 1.19, rtol=0, atol=.01)
s4 = est.score(Y=y, T=T, X=X, W=W, scoring='r2')
np.testing.assert_allclose(s4, 0.113, rtol=0, atol=.001)
s5 = est.score(Y=y, T=T, X=X, W=W, scoring='pearsonr')
np.testing.assert_allclose(s5[0], 0.337, rtol=0, atol=0.005)

sn1 = est.score_nuisances(Y=y, T=T, X=X, W=W,
t_scoring='mean_squared_error',
y_scoring='mean_squared_error')
np.testing.assert_allclose(sn1['Y_mean_squared_error'], [2.8,2.8], rtol=0, atol=.1)
np.testing.assert_allclose(sn1['T_mean_squared_error'], [1.5,1.5], rtol=0, atol=.1)

sn2 = est.score_nuisances(Y=y, T=T, X=X, W=W,
t_scoring='mean_absolute_error',
y_scoring='mean_absolute_error')
np.testing.assert_allclose(sn2['Y_mean_absolute_error'], [1.3,1.3], rtol=0, atol=.1)
np.testing.assert_allclose(sn2['T_mean_absolute_error'], [1.0,1.0], rtol=0, atol=.1)

sn3 = est.score_nuisances(Y=y, T=T, X=X, W=W,
t_scoring='r2',
y_scoring='r2')
np.testing.assert_allclose(sn3['Y_r2'], [0.27,0.27], rtol=0, atol=.005)
np.testing.assert_allclose(sn3['T_r2'], [-5.1,-5.1], rtol=0, atol=0.25)

sn4 = est.score_nuisances(Y=y, T=T, X=X, W=W,
t_scoring='pearsonr',
y_scoring='pearsonr')
# Ignoring the p-values returned with the score
y_pearsonr = [s[0] for s in sn4['Y_pearsonr']]
t_pearsonr = [s[0] for s in sn4['T_pearsonr']]
np.testing.assert_allclose(y_pearsonr, [0.52, 0.52], rtol=0, atol=.01)
np.testing.assert_allclose(t_pearsonr, [.035, .035], rtol=0, atol=0.005)

# T is binary, and can be used to check binary eval functions
sn5 = est.score_nuisances(Y=y, T=T, X=X, W=W, t_scoring='roc_auc')
np.testing.assert_allclose(sn5['T_roc_auc'], [0.526,0.526], rtol=0, atol=.005)

sn6 = est.score_nuisances(Y=y, T=T, X=X, W=W, t_scoring='log_loss')
np.testing.assert_allclose(sn6['T_log_loss'], [17.4,17.4], rtol=0, atol=0.1)

def test_aaforest_pandas(self):
"""Test that we can use CausalForest with pandas inputs."""
df = pd.DataFrame({'a': np.random.normal(size=500),
3 changes: 3 additions & 0 deletions econml/tests/test_ortho_learner.py
@@ -314,6 +314,9 @@ def _gen_ortho_learner_model_final(self):
np.testing.assert_almost_equal(est.score_, sigma**2, decimal=2)
np.testing.assert_almost_equal(est.ortho_learner_model_final_.model.coef_[0], 1, decimal=2)

# Test that non-standard scoring raises the appropriate exception
self.assertRaises(NotImplementedError, est.score, y, X[:, 0], W=X[:, 1:], scoring='mean_squared_error')

@pytest.mark.ray
def test_ol_with_ray(self):
self._test_ol(True)