Skip to content

Commit

Permalink
add functionality to cs config to read TMD file
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed May 30, 2024
1 parent c993fbc commit 778ea30
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
28 changes: 27 additions & 1 deletion cs-config/cs_config/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
postprocess,
nth_year_results,
retrieve_puf,
retrieve_tmd,
)
from .outputs import create_layout, aggregate_plot
from taxbrain import TaxBrain, report
Expand All @@ -25,6 +26,9 @@
# S3 URIs of the input data files. Each can be overridden through the
# environment variable named in the lookup; otherwise the bundled default
# location is used.
PUF_S3_FILE_LOCATION = os.getenv(
    "PUF_S3_LOCATION",
    "s3://ospc-data-files/puf.20210720.csv.gz",
)
# NOTE(review): the TMD default carries the same date stamp (20210720) as
# the PUF default -- confirm this is the real object name in the bucket.
TMD_S3_FILE_LOCATION = os.getenv(
    "TMD_S3_LOCATION",
    "s3://ospc-data-files/tmd.20210720.csv.gz",
)

CUR_PATH = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -125,7 +129,24 @@ def run_model(meta_params_dict, adjustment):
# Access keys are not available. Default to the CPS.
print("Defaulting to the CPS")
meta_params.adjust({"data_source": "CPS"})
if meta_params.data_source == "CPS":
elif meta_params.data_source == "TMD":
tmd_df = retrieve_tmd(
TMD_S3_FILE_LOCATION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
)
if tmd_df is not None:
if not isinstance(tmd_df, pd.DataFrame):
raise TypeError("'tmd_df' must be a Pandas DataFrame.")
fuzz = True
sampling_frac = 0.05
sampling_seed = 2222
full_sample = tmd_df
data_start_year = taxcalc.Records.TMDCSV_YEAR
weights = taxcalc.Records.TMD_WEIGHTS_FILENAME
else:
# Access keys are not available. Default to the CPS.
print("Defaulting to the CPS")
meta_params.adjust({"data_source": "CPS"})
elif meta_params.data_source == "CPS":
fuzz = False
input_path = os.path.join(TCDIR, "cps.csv.gz")
# full_sample = read_egg_csv(cpspath) # pragma: no cover
Expand All @@ -134,6 +155,11 @@ def run_model(meta_params_dict, adjustment):
full_sample = pd.read_csv(input_path)
data_start_year = taxcalc.Records.CPSCSV_YEAR
weights = taxcalc.Records.CPS_WEIGHTS_FILENAME
else:
raise ValueError(
f"Data source '{meta_params.data_source}' is not supported."
)

if meta_params.use_full_sample:
sample = full_sample
end_year = min(start_year + 10, TaxBrain.LAST_BUDGET_YEAR)
Expand Down
43 changes: 42 additions & 1 deletion cs-config/cs_config/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@
"PUF_S3_LOCATION", "s3://ospc-data-files/puf.20210720.csv.gz"
)

# S3 URI of the TMD file; the TMD_S3_LOCATION environment variable takes
# precedence over the default below.
# NOTE(review): the default reuses the 20210720 date stamp seen in the PUF
# default -- confirm this is the real object name in the bucket.
TMD_S3_FILE_LOCATION = os.getenv(
    "TMD_S3_LOCATION",
    "s3://ospc-data-files/tmd.20210720.csv.gz",
)


def random_seed(user_mods, year):
"""
Expand Down Expand Up @@ -376,7 +380,7 @@ def retrieve_puf(
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
):
"""
Function for retrieving the PUF from the OSPC S3 bucket
Function for retrieving the PUF from the S3 bucket
"""
s3_reader_installed = S3FileSystem is not None
has_credentials = (
Expand Down Expand Up @@ -405,3 +409,40 @@ def retrieve_puf(
f"s3_reader_installed={s3_reader_installed})"
)
return None


def retrieve_tmd(
    tmd_s3_file_location=TMD_S3_FILE_LOCATION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
):
    """
    Function for retrieving the TMD file from the S3 bucket

    Sources are tried in this order:
      1. the S3 object at ``tmd_s3_file_location`` (requires both
         credentials and an importable ``S3FileSystem``),
      2. a local ``tmd.csv.gz`` in the current working directory,
      3. a local ``tmd.csv`` in the current working directory.

    Args:
        tmd_s3_file_location: S3 URI of the TMD CSV file (optionally
            gzip-compressed, detected by a ``.gz`` suffix).
        aws_access_key_id: access key used to authenticate with S3.
        aws_secret_access_key: secret key used to authenticate with S3.

    Returns:
        pandas.DataFrame with the TMD data, or None (after emitting a
        warning) when no source is available.
    """
    s3_reader_installed = S3FileSystem is not None
    has_credentials = (
        aws_access_key_id is not None and aws_secret_access_key is not None
    )
    if tmd_s3_file_location and has_credentials and s3_reader_installed:
        print("Reading tmd from S3 bucket.", tmd_s3_file_location)
        # Authenticate with the credentials handed to this function (the
        # same ones checked in has_credentials), not the module globals,
        # so that caller-supplied keys are actually honored.
        fs = S3FileSystem(
            key=aws_access_key_id,
            secret=aws_secret_access_key,
        )
        # pandas cannot infer compression from an open file object, so
        # derive it from the object name; None means "read as plain CSV".
        compression = (
            "gzip" if str(tmd_s3_file_location).endswith(".gz") else None
        )
        with fs.open(tmd_s3_file_location) as f:
            tmd_df = pd.read_csv(f, compression=compression)
        return tmd_df
    elif Path("tmd.csv.gz").exists():
        print("Reading tmd from tmd.csv.gz.")
        return pd.read_csv("tmd.csv.gz", compression="gzip")
    elif Path("tmd.csv").exists():
        print("Reading tmd from tmd.csv.")
        return pd.read_csv("tmd.csv")
    else:
        warnings.warn(
            f"TMD file not available (tmd_location={tmd_s3_file_location}, "
            f"has_credentials={has_credentials}, "
            f"s3_reader_installed={s3_reader_installed})"
        )
        return None

0 comments on commit 778ea30

Please sign in to comment.