From a1adf48935776eaf8ee571184318305e6e7977d7 Mon Sep 17 00:00:00 2001 From: Roni Kobrosly Date: Sun, 8 Sep 2024 14:16:17 -0400 Subject: [PATCH 1/6] added warning message and test for issue #837 --- econml/_ortho_learner.py | 9 +++++++++ econml/tests/test_dmliv.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py index b411314da..e84aff948 100644 --- a/econml/_ortho_learner.py +++ b/econml/_ortho_learner.py @@ -834,6 +834,15 @@ def _fit_nuisances(Y, T, X, W, Z, sample_weight, groups): elif not np.array_equal(fitted_inds, new_inds): raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated") + if nuisances[1].sum() == 0: + raise ValueError( + """ + In fitting nuisances, the estimates for E[T|Z,X,W] are identical to E[T|X,W], + resulting in a situation where the rows will all be weighted to zero. Please + examine your instrument variable `Z`. + """ + ) + if self.mc_iters is not None: if self.mc_agg == 'mean': nuisances = tuple(np.mean(nuisance_mc_variants, axis=0) diff --git a/econml/tests/test_dmliv.py b/econml/tests/test_dmliv.py index 1fd491e22..0c9a56bed 100644 --- a/econml/tests/test_dmliv.py +++ b/econml/tests/test_dmliv.py @@ -254,3 +254,22 @@ def test_groups(self): est.fit(y, T, Z=Z, X=X, W=W, groups=groups) est.score(y, T, Z=Z, X=X, W=W) est.const_marginal_effect(X) + + def test_row_zero_weight_failure_mode(self): + np.random.seed(784) + + n = 100 + d_x = 3 + + Y = np.random.normal(size=(n,)) + T = np.random.normal(size=(n,)) + X = np.random.normal(size=(n, d_x)) + Z = np.random.normal(size=(n,)) + + est = NonParamDMLIV( + discrete_instrument=False, + discrete_treatment=False, + model_final=LinearRegression() + ) + with pytest.raises(ValueError, match=r" examine your instrument variable "): + est.fit(Y, T, Z=Z, X=X) From a3c6439fff655caab23f2ad7f91e043d52dab1be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Sep 2024 18:17:24 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- econml/tests/test_dmliv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/econml/tests/test_dmliv.py b/econml/tests/test_dmliv.py index 0c9a56bed..48d20b793 100644 --- a/econml/tests/test_dmliv.py +++ b/econml/tests/test_dmliv.py @@ -267,8 +267,8 @@ def test_row_zero_weight_failure_mode(self): Z = np.random.normal(size=(n,)) est = NonParamDMLIV( - discrete_instrument=False, - discrete_treatment=False, + discrete_instrument=False, + discrete_treatment=False, model_final=LinearRegression() ) with pytest.raises(ValueError, match=r" examine your instrument variable "): From ca78ec0646587d26960d68f47c17a9e8e8510863 Mon Sep 17 00:00:00 2001 From: Roni Kobrosly Date: Tue, 24 Sep 2024 20:18:07 -0400 Subject: [PATCH 3/6] Updated location of T_res error message Signed-off-by: Roni Kobrosly --- econml/_ortho_learner.py | 9 --------- econml/iv/dml/_dml.py | 8 ++++++++ econml/tests/test_dmliv.py | 21 +-------------------- 3 files changed, 9 insertions(+), 29 deletions(-) diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py index e84aff948..b411314da 100644 --- a/econml/_ortho_learner.py +++ b/econml/_ortho_learner.py @@ -834,15 +834,6 @@ def _fit_nuisances(Y, T, X, W, Z, sample_weight, groups): elif not np.array_equal(fitted_inds, new_inds): raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated") - if nuisances[1].sum() == 0: - raise ValueError( - """ - In fitting nuisances, the estimates for E[T|Z,X,W] are identical to E[T|X,W], - resulting in a situation where the rows will all be weighted to zero. Please - examine your instrument variable `Z`. - """ - ) - if self.mc_iters is not None: if self.mc_agg == 'mean': nuisances = tuple(np.mean(nuisance_mc_variants, axis=0) diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index 741162147..e76aa1de4 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -761,6 +761,14 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None) TX_pred = np.tile(TX_pred.reshape(1, -1), (T.shape[0], 1)) Y_res = Y - Y_pred.reshape(Y.shape) T_res = TXZ_pred.reshape(T.shape) - TX_pred.reshape(T.shape) + if T_res.sum() == 0: + raise ValueError( + """ + All values of the treatment residual are 0, + which then makes them unsuitable to use as weights + in downstream in econml/dml/dml.py + """ + ) return Y_res, T_res diff --git a/econml/tests/test_dmliv.py b/econml/tests/test_dmliv.py index 48d20b793..f3ac0c281 100644 --- a/econml/tests/test_dmliv.py +++ b/econml/tests/test_dmliv.py @@ -253,23 +253,4 @@ def test_groups(self): with self.subTest(est=est): est.fit(y, T, Z=Z, X=X, W=W, groups=groups) est.score(y, T, Z=Z, X=X, W=W) - est.const_marginal_effect(X) - - def test_row_zero_weight_failure_mode(self): - np.random.seed(784) - - n = 100 - d_x = 3 - - Y = np.random.normal(size=(n,)) - T = np.random.normal(size=(n,)) - X = np.random.normal(size=(n, d_x)) - Z = np.random.normal(size=(n,)) - - est = NonParamDMLIV( - discrete_instrument=False, - discrete_treatment=False, - model_final=LinearRegression() - ) - with pytest.raises(ValueError, match=r" examine your instrument variable "): - est.fit(Y, T, Z=Z, X=X) + est.const_marginal_effect(X) \ No newline at end of file From 1df31c2033516185c99c369373d277975663d596 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 00:19:34 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- econml/tests/test_dmliv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/econml/tests/test_dmliv.py b/econml/tests/test_dmliv.py index f3ac0c281..1fd491e22 100644 --- a/econml/tests/test_dmliv.py +++ b/econml/tests/test_dmliv.py @@ -253,4 +253,4 @@ def test_groups(self): with self.subTest(est=est): est.fit(y, T, Z=Z, X=X, W=W, groups=groups) est.score(y, T, Z=Z, X=X, W=W) - est.const_marginal_effect(X) \ No newline at end of file + est.const_marginal_effect(X) From 70cc2dd45daf9ea4dc5cf3d17a647aba770dc743 Mon Sep 17 00:00:00 2001 From: Roni Kobrosly Date: Fri, 27 Sep 2024 16:18:29 -0400 Subject: [PATCH 5/6] Update econml/iv/dml/_dml.py Co-authored-by: Keith Battocchi Signed-off-by: Roni Kobrosly --- econml/iv/dml/_dml.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index e76aa1de4..835bbb522 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -764,9 +764,14 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None) if T_res.sum() == 0: raise ValueError( """ - All values of the treatment residual are 0, - which then makes them unsuitable to use as weights - in downstream in econml/dml/dml.py + All values of the treatment residual are 0, + which then makes them unsuitable to use as weights. + DRIV requires that the instrument Z has an effect on the + expected treatment value of at least some rows. + + If you are using regularized models, it's possible that this error is a + result of regularizing too strongly, so that all predictions from both + models are constant. """ ) return Y_res, T_res From f013ed589034d84a6ffb0d4deccfb95fa57c4c32 Mon Sep 17 00:00:00 2001 From: Roni Kobrosly Date: Fri, 27 Sep 2024 16:18:38 -0400 Subject: [PATCH 6/6] Update econml/iv/dml/_dml.py Co-authored-by: Keith Battocchi Signed-off-by: Roni Kobrosly --- econml/iv/dml/_dml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index 835bbb522..0c0b90655 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -761,7 +761,7 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None) TX_pred = np.tile(TX_pred.reshape(1, -1), (T.shape[0], 1)) Y_res = Y - Y_pred.reshape(Y.shape) T_res = TXZ_pred.reshape(T.shape) - TX_pred.reshape(T.shape) - if T_res.sum() == 0: + if not T_res.any(): raise ValueError( """ All values of the treatment residual are 0,