Skip to content

Commit

Permalink
Add logic so that _ortho_learner raises an exception if non-standard …
Browse files Browse the repository at this point in the history
…scoring is used without an _rlearner._ModelFinal

Test that logic in the _otho_learner test
  • Loading branch information
carl-offerfit committed Feb 20, 2025
1 parent 3abd752 commit 39cc39d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
20 changes: 17 additions & 3 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,10 +1116,24 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, s

accumulated_nuisances += nuisances

score_kwargs = {
'X': X,
'W': W,
'Z': Z,
'sample_weight': sample_weight,
'groups': groups
}
# If using an _rlearner, the scoring parameter can be passed along, if provided
if scoring is not None:
# Cannot import in header, or circular imports
from .dml._rlearner import _ModelFinal
if isinstance(self._ortho_learner_model_final, _ModelFinal):
score_kwargs['scoring'] = scoring
else:
raise NotImplementedError("scoring parameter only implemented for "
"_rlearner._ModelFinal")
return self._ortho_learner_model_final.score(Y, T, nuisances=accumulated_nuisances,
**filter_none_kwargs(X=X, W=W, Z=Z,
sample_weight=sample_weight,
groups=groups, scoring=scoring))
**filter_none_kwargs(**score_kwargs))

@property
def ortho_learner_model_final_(self):
Expand Down
3 changes: 3 additions & 0 deletions econml/tests/test_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def _gen_ortho_learner_model_final(self):
np.testing.assert_almost_equal(est.score_, sigma**2, decimal=2)
np.testing.assert_almost_equal(est.ortho_learner_model_final_.model.coef_[0], 1, decimal=2)

# Test that non-standard scoring raise the appropriate exception
self.assertRaises(NotImplementedError, est.score, y, X[:, 0], W=X[:, 1:], scoring='mean_squared_error')

@pytest.mark.ray
def test_ol_with_ray(self):
self._test_ol(True)
Expand Down

0 comments on commit 39cc39d

Please sign in to comment.