Skip to content

Commit

Permalink
Merge pull request #3628 from catalyst-cooperative/tweak-eia860m-settings
Browse files Browse the repository at this point in the history

Tweak eia860m settings validations and tests.
  • Loading branch information
aesharpe authored May 13, 2024
2 parents 8bd4e96 + 69c0449 commit 868ccdb
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/pudl/extract/eia860m.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def raw_eia860m__all_dfs(context):

eia860m_extractor = Extractor(ds=ds)
raw_eia860m__all_dfs = eia860m_extractor.extract(
year_month=eia_settings.eia860m.year_month
year_month=eia_settings.eia860m.year_months
)
return raw_eia860m__all_dfs

Expand Down
35 changes: 17 additions & 18 deletions src/pudl/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
field_validator,
model_validator,
)
from pydantic_core.core_schema import FieldValidationInfo
from pydantic_settings import BaseSettings

import pudl
Expand Down Expand Up @@ -257,7 +256,7 @@ class Eia860Settings(GenericDatasetSettings):

@field_validator("eia860m_year_months")
@classmethod
def add_other_860m_years(cls, v, info: FieldValidationInfo) -> list[str]:
def add_other_860m_years(cls, v, info: ValidationInfo) -> list[str]:
"""Find extra years from EIA860m if applicable.
There's a gap in reporting (after the new year but before the EIA 860 early
Expand All @@ -266,45 +265,44 @@ def add_other_860m_years(cls, v, info: FieldValidationInfo) -> list[str]:
year of 860m data that is not yet available in 860.
"""
extra_eia860m_year_months = []
if info.data["eia860m"]:
all_eia860m_years = {
date.split("-")[0] for date in info.data["all_eia860m_year_months"]
}
all_eia860m_years = set(
pd.to_datetime(info.data["all_eia860m_year_months"]).year
)
# The years in 860m that are not in 860
extra_eia860m_years = {
year
for year in all_eia860m_years
if int(year) not in info.data["years"]
year for year in all_eia860m_years if year not in info.data["years"]
}
# The years already listed as variables in eia860m_year_months
years_in_v = {date.split("-")[0] for date in v}
years_in_v = set(pd.to_datetime(v).year)
# The max year_month values available in 860m for each year not
# covered by EIA860 (and not already listed in the eia860m_year_months
# variable)
extra_eia860m_year_months = [
max(
date
for date in info.data["all_eia860m_year_months"]
if date.startswith(year)
for date in pd.to_datetime(info.data["all_eia860m_year_months"])
if date.year == year
)
for year in (extra_eia860m_years - years_in_v)
if year > max(info.data["years"])
]
return v + extra_eia860m_year_months
return v
return v + list(pd.Series(extra_eia860m_year_months).dt.strftime("%Y-%m"))

@field_validator("eia860m_year_months")
@classmethod
def no_repeat_years(cls, v, info: FieldValidationInfo) -> list[str]:
def no_repeat_years(cls, v, info: ValidationInfo) -> list[str]:
"""Make sure there are no duplicate 860m year values."""
if info.data["eia860m"]:
years_in_v = [date.split("-")[0] for date in v]
years_in_v = pd.to_datetime(v).year
if len(years_in_v) != len(set(years_in_v)):
raise ValueError(f"{v} contains duplicate year values.")
return v

@field_validator("eia860m_year_months")
@classmethod
def eia860_variable_values_exist(cls, v, info: FieldValidationInfo) -> list[str]:
def eia860_variable_values_exist(cls, v, info: ValidationInfo) -> list[str]:
"""Check that the year_month values for eia860m_year_months exist."""
if info.data["eia860m"]:
for year_month in v:
Expand All @@ -314,12 +312,13 @@ def eia860_variable_values_exist(cls, v, info: FieldValidationInfo) -> list[str]

@field_validator("eia860m_year_months")
@classmethod
def only_years_not_in_eia860(cls, v, info: FieldValidationInfo) -> list[str]:
def only_years_not_in_eia860(cls, v, info: ValidationInfo) -> list[str]:
"""Ensure no EIA860m values are from years already in EIA860."""
if info.data["eia860m"]:
for year in {date.split("-")[0] for date in v}:
for year in pd.to_datetime(v).year.unique():
if year in info.data["years"]:
raise ValueError(f"EIA860m year {year} available in EIA860")
return v


class Eia860mSettings(GenericDatasetSettings):
Expand Down
29 changes: 18 additions & 11 deletions test/unit/settings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ def test_missing_field_error(self: Self):
working_tables = ["table"]

class Test(GenericDatasetSettings):
data_source: DataSource(
working_partitions=working_partitions, working_tables=working_tables
data_source: DataSource = DataSource(
working_partitions=working_partitions,
working_tables=working_tables,
)

Test()
Expand Down Expand Up @@ -125,18 +126,24 @@ def test_none_quarters_raise(self: Self):
_ = EpaCemsSettings(quarters=None)


class TestEIA860Settings:
class TestEia860Settings:
"""Test EIA860 setting validation."""

def test_860m(self: Self):
"""Test validation error is raised when eia860m date is within 860 years."""
settings_cls = Eia860Settings()
original_eia80m_year_month = settings_cls.eia860m_year_months
settings_cls.eia860m_year_months = ["2019-11"]
def test_eia860_years_overlap_eia860m_years(self: Self):
"""Test validation error is raised when eia860m date is within eia860 years."""
# Identify the last valid EIA-860 year:
max_eia860_year = max(Eia860Settings().years)
# Use that year to construct an EIA-860M year that overlaps the EIA-860 years:
bad_eia860m_year_month = f"{max_eia860_year}-01"

with pytest.raises(ValueError):
settings_cls(eia860m=True)
settings_cls.eia860m_year_months = original_eia80m_year_month
# Attempt to construct an EIA-860 settings object with an EIA-860M year that
# overlaps the EIA-860 years, which should result in a ValidationError:
with pytest.raises(ValidationError):
_ = Eia860Settings(
eia860m=True,
years=[max_eia860_year],
eia860m_year_months=[bad_eia860m_year_month],
)


class TestEia860mSettings:
Expand Down

0 comments on commit 868ccdb

Please sign in to comment.