Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Estimator return types #264

Merged
merged 40 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0eefd4b
Patsy interaction treatment terms can now be read
jmafoster1 Feb 7, 2024
c255ce0
refactor estimate_coefficient to only return pd.Series
christopher-wild Feb 13, 2024
4736391
Adapt unit tests to access series values for coefficients
christopher-wild Feb 13, 2024
1eac357
Fetch factors from Patsy to check types
christopher-wild Feb 13, 2024
f616091
Return float rather than Series for ci_high and ci_low
christopher-wild Feb 13, 2024
1e8ad89
Handle float and series confidence intervals
christopher-wild Feb 14, 2024
abe8599
Handle correct exception
christopher-wild Feb 14, 2024
767e5b4
More flexible handling due to multiple float types returned by estimators
christopher-wild Feb 14, 2024
2724d94
Handle series values
christopher-wild Feb 14, 2024
da87421
Linting
christopher-wild Feb 14, 2024
aca8bf9
Merge branch 'main' into interaction-terms
christopher-wild Feb 14, 2024
ea8c273
refactor all estimate_* return types to be pd.Series for LinearRegres…
christopher-wild Feb 14, 2024
c777503
Update return typings
christopher-wild Feb 14, 2024
fd7f79d
Refactor other estimator classes to return pd.Series
christopher-wild Feb 14, 2024
8f9499d
Extract bool from series
christopher-wild Feb 14, 2024
3bcefc7
Remove gen expression so elements can be indexed
christopher-wild Feb 14, 2024
e43bf38
All effects now expect pd.Series for the test values
christopher-wild Feb 14, 2024
ace2612
Update all unit tests to work with pd.Series refactor
christopher-wild Feb 14, 2024
c942537
Merge remote-tracking branch 'origin/interaction-terms' into interact…
christopher-wild Feb 14, 2024
18883d8
example_poisson_process.py now works with pd.Series refactor
christopher-wild Feb 14, 2024
8590dc9
Update typing
christopher-wild Feb 23, 2024
8e33b25
Update tests to use pd.Series for confidence intervals
christopher-wild Feb 23, 2024
9641319
Dictionary assertions use list CIs
christopher-wild Feb 23, 2024
35bae1f
SomeEffect and NoneEffect apply methods now work with pd.Series
christopher-wild Feb 23, 2024
6a5987a
_get_confidence_intervals method returns pd.Series
christopher-wild Feb 23, 2024
d0322ed
Merge branch 'main' into interaction-terms
christopher-wild Feb 23, 2024
33e7e53
Remove unnecessary unpacking of value
christopher-wild Feb 23, 2024
123d4db
tests represent the logic of returning Series better
christopher-wild Feb 23, 2024
17b8692
Pylint suggestions
christopher-wild Feb 23, 2024
4e06f34
Update surrogate code for new series return vals
rsomers1998 Feb 26, 2024
f66f854
Merge branch 'interaction-terms' of https://github.com/CITCOM-project…
rsomers1998 Feb 26, 2024
d742f74
Remove unused import
christopher-wild Feb 27, 2024
425329b
Update LR91 examples
christopher-wild Feb 27, 2024
c026f6a
Update example_beta.py
christopher-wild Feb 27, 2024
eb6bca6
Raise exception for Positive and Negative effect if multiple values p…
christopher-wild Feb 27, 2024
5265e9f
Fix typo in check for value length
christopher-wild Feb 27, 2024
fb287ec
Add test for catching multiple value exception
christopher-wild Feb 27, 2024
844849a
Use pandas inbuilt assert_series_equal test instead of casting everyt…
christopher-wild Feb 27, 2024
6029edb
Add limitation of single test_value to Effect docstrings
christopher-wild Feb 28, 2024
b8ad419
Ensure only single ate values are provided in surrogate_models
christopher-wild Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion causal_testing/testing/causal_test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
import numpy as np
import pandas as pd

from causal_testing.testing.causal_test_result import CausalTestResult

Expand Down Expand Up @@ -57,7 +58,7 @@ def apply(self, res: CausalTestResult) -> bool:
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]

value = value[0] if isinstance(value[0], pd.Series) else value
christopher-wild marked this conversation as resolved.
Show resolved Hide resolved
return (
sum(
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
Expand Down
4 changes: 4 additions & 0 deletions causal_testing/testing/causal_test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,16 @@ def to_dict(self, json=False):
def ci_low(self):
    """Return the lower bracket of the confidence intervals.

    :return: None when no confidence intervals are set; the first element (by position) of the lower bound
        when it is stored as a pandas Series; otherwise the lower bound as stored.
    """
    if not self.confidence_intervals:
        return None
    lower = self.confidence_intervals[0]
    if isinstance(lower, pd.Series):
        # Select by position. A plain ``lower[0]`` is a label lookup: it raises KeyError for an integer index
        # without label 0 and relies on a deprecated positional fallback for non-integer indexes.
        return lower.iloc[0]
    return lower

def ci_high(self):
    """Return the higher bracket of the confidence intervals.

    :return: None when no confidence intervals are set; the first element (by position) of the upper bound
        when it is stored as a pandas Series; otherwise the upper bound as stored.
    """
    if not self.confidence_intervals:
        return None
    upper = self.confidence_intervals[1]
    if isinstance(upper, pd.Series):
        # Select by position. A plain ``upper[0]`` is a label lookup: it raises KeyError for an integer index
        # without label 0 and relies on a deprecated positional fallback for non-integer indexes.
        return upper.iloc[0]
    return upper

Expand Down
12 changes: 5 additions & 7 deletions causal_testing/testing/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import statsmodels.formula.api as smf
from econml.dml import CausalForestDML
from patsy import dmatrix # pylint: disable = no-name-in-module

from patsy import ModelDesc
from sklearn.ensemble import GradientBoostingRegressor
from statsmodels.regression.linear_model import RegressionResultsWrapper
from statsmodels.tools.sm_exceptions import PerfectSeparationError
Expand Down Expand Up @@ -351,19 +351,17 @@ def estimate_coefficient(self) -> float:
"""
model = self._run_linear_regression()
newline = "\n"
treatment = [self.treatment]
if str(self.df.dtypes[self.treatment]) == "object":
patsy_md = ModelDesc.from_formula(self.treatment)
if any((self.df.dtypes[factor.name()] == 'object' for factor in patsy_md.rhs_termlist[1].factors)):
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
else:
treatment = [self.treatment]
assert set(treatment).issubset(
model.params.index.tolist()
), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
if str(self.df.dtypes[self.treatment]) != "object":
unit_effect = unit_effect[0]
ci_low = ci_low[0]
ci_high = ci_high[0]
return unit_effect, [ci_low, ci_high]

def estimate_ate(self) -> tuple[float, list[float, float], float]:
Expand Down
8 changes: 4 additions & 4 deletions tests/testing_tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_program_11_2(self):
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)

# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
self.assertEqual(round(model.params["treatments"], 1), round(ate, 1))
self.assertEqual(round(model.params["treatments"], 1), round(ate[0], 1))
christopher-wild marked this conversation as resolved.
Show resolved Hide resolved

def test_program_11_3(self):
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
Expand All @@ -251,7 +251,7 @@ def test_program_11_3(self):
197.1,
)
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
self.assertEqual(round(model.params["treatments"], 3), round(ate, 3))
self.assertEqual(round(model.params["treatments"], 3), round(ate[0], 3))

def test_program_15_1A(self):
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""
Expand Down Expand Up @@ -329,8 +329,8 @@ def test_program_15_no_interaction(self):
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
# for term_to_square in terms_to_square:
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_coefficient()
self.assertEqual(round(ate, 1), 3.5)
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [2.6, 4.3])
self.assertEqual(round(ate[0], 1), 3.5)
self.assertEqual([round(ci_low[0], 1), round(ci_high[0], 1)], [2.6, 4.3])

def test_program_15_no_interaction_ate(self):
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
Expand Down
Loading