Skip to content

Commit

Permalink
Adjust windows so that method=resample and method=rolling align. …
Browse files Browse the repository at this point in the history
…Also set `skipna=False` when computing statistics. This is anticipation of the following development workflow:

* Make '1w' rolling window version of ground truth
  * Compute '1w' rolling window version of 0.25degree ERA5, 1990-2023, starting from 1h spaced ERA
  * Regrid and compute derived variables
  * Compute climatology
* Make resampled forecasts
  * Forecast to 6-16 week lead times, with 12hr timedelta spacing, and init time spacing that is an odd multiple of 12h.
  * resample forecasts to weekly and delete (ttl) original forecasts to save space
* Evaluate resampled forecasts vs. rolled ground truth

Using `rolling` for ground truth is better than `resample`. If we used `resample`, then we'd have to be sure our forecast start times and timedeltas were aligned with the window used for ground truth

Clearly we could also roll the forecasts, but that barely increases statistical power and wastes time & space. For final WB evals, that may make sense, but for development we want something faster and more lightweight.

PiperOrigin-RevId: 631443536
  • Loading branch information
langmore authored and Weatherbench2 authors committed May 7, 2024
1 parent 9064d70 commit 1f0a1d6
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 24 deletions.
35 changes: 29 additions & 6 deletions scripts/resample_in_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@
'method',
'resample',
['resample', 'rolling'],
help='Whether to resample to new times, or use a rolling window.',
help=(
'Whether to resample to new times (spaced by --period), or use a'
' rolling window. In either case, output at time index T uses the'
' window [T, T + period]. In particular, whether using resample or'
' rolling, output at matching times will be the same.'
),
)
PERIOD = flags.DEFINE_string(
'period',
Expand Down Expand Up @@ -234,21 +239,25 @@ def resample_in_time_core(
f'{delta_t=} between chunk times did not evenly divide {period=}'
)
return getattr(
chunk.rolling({TIME_DIM.value: period // delta_t}),
chunk.rolling(
{TIME_DIM.value: period // delta_t}, center=False, min_periods=None
),
statistic,
)()
)(skipna=False)
elif method == 'resample':
return getattr(
chunk.resample({TIME_DIM.value: period}),
chunk.resample({TIME_DIM.value: period}, label='left'),
statistic,
)()
)(skipna=False)
else:
raise ValueError(f'Unhandled {method=}')


def main(argv: abc.Sequence[str]) -> None:

ds, input_chunks = xbeam.open_zarr(INPUT_PATH.value)
period = pd.to_timedelta(PERIOD.value)

if TIME_START.value is not None or TIME_STOP.value is not None:
ds = ds.sel({TIME_DIM.value: slice(TIME_START.value, TIME_STOP.value)})

Expand All @@ -267,8 +276,22 @@ def main(argv: abc.Sequence[str]) -> None:
)
ds = ds[keep_vars]

# To ensure results at time T use data from [T, T + period], an offset needs
# to be added if the method is rolling.
# It would be wonderful if this was the default, or possible with appropriate
# kwargs in rolling, but alas...
if METHOD.value == 'rolling':
delta_ts = pd.to_timedelta(np.unique(np.diff(ds[TIME_DIM.value].data)))
if len(delta_ts) != 1:
raise ValueError(
f'Input data must have constant spacing. Found {delta_ts}'
)
delta_t = delta_ts[0]
ds = ds.assign_coords(
{TIME_DIM.value: ds[TIME_DIM.value] - period + delta_t}
)

# Make the template
period = pd.to_timedelta(PERIOD.value)
if METHOD.value == 'resample':
rsmp_times = resample_in_time_core(
# All stats will give the same times, so use 'mean' arbitrarily.
Expand Down
150 changes: 132 additions & 18 deletions scripts/resample_in_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

class ResampleInTimeTest(parameterized.TestCase):

def test_demonstrating_returned_times_for_resample(self):
@parameterized.named_parameters(
dict(testcase_name='NoNaN', insert_nan=False),
dict(testcase_name='YesNaN', insert_nan=True),
)
def test_demonstrating_resample_and_rolling_are_aligned(self, insert_nan):
# times = 10 days, starting at Jan 1
times = pd.DatetimeIndex(
[
Expand All @@ -43,7 +47,14 @@ def test_demonstrating_returned_times_for_resample(self):
'2023-01-10',
]
)
temperatures = np.arange(len(times))
temperatures = np.arange(len(times)).astype(float)

if insert_nan:
# NaN inserted to (i) verify skipna=False, and (ii) verify correct setting
# for min_periods. If e.g. min_periods=1, then NaN values get skipped so
# long as there is at least one non-NaN value!
temperatures[0] = np.nan

input_ds = xr.Dataset(
{
'temperature': xr.DataArray(
Expand All @@ -53,37 +64,128 @@ def test_demonstrating_returned_times_for_resample(self):
)

input_path = self.create_tempdir('source').full_path
output_path = self.create_tempdir('destination').full_path

input_ds.to_zarr(input_path)

# Get resampled output
resample_output_path = self.create_tempdir('resample').full_path
with flagsaver.as_parsed(
input_path=input_path,
output_path=output_path,
output_path=resample_output_path,
method='resample',
period='1w',
period='3d',
mean_vars='ALL',
runner='DirectRunner',
):
resample_in_time.main([])
resample, unused_output_chunks = xarray_beam.open_zarr(resample_output_path)

output_ds, unused_output_chunks = xarray_beam.open_zarr(output_path)
# Show that the output at time T uses data from the window [T, T + period]
np.testing.assert_array_equal(
output_ds.time,
# The first output time is the first input time
# The second output time is the first + 1w
np.array(
['2023-01-01T00:00:00.000000000', '2023-01-08T00:00:00.000000000'],
dtype='datetime64[ns]',
pd.to_datetime(resample.time),
pd.DatetimeIndex(
['2023-01-01', '2023-01-04', '2023-01-07', '2023-01-10']
),
)

np.testing.assert_array_equal(
output_ds.temperature.data,
# The first temperature is the average of the first 7 times
# The second temperature is the average of the remaining times (of which
# there are only 3).
[np.mean(temperatures[:7]), np.mean(temperatures[7:14])],
resample.temperature.data,
[
np.mean(temperatures[:3]), # Will be NaN if `insert_nan`
np.mean(temperatures[3:6]),
np.mean(temperatures[6:9]),
np.mean(temperatures[9:12]),
],
)

# Get rolled output
rolling_output_path = self.create_tempdir('rolling').full_path
with flagsaver.as_parsed(
input_path=input_path,
output_path=rolling_output_path,
method='rolling',
period='3d',
mean_vars='ALL',
runner='DirectRunner',
):
resample_in_time.main([])
rolling, unused_output_chunks = xarray_beam.open_zarr(rolling_output_path)

common_times = pd.DatetimeIndex(['2023-01-01', '2023-01-04', '2023-01-07'])
xr.testing.assert_equal(
resample.sel(time=common_times),
rolling.sel(time=common_times),
)

@parameterized.parameters(
(20, '3d', None),
(21, '3d', None),
(21, '8d', None),
(5, '1d', None),
(20, '3d', [0, 4, 8]),
(21, '3d', [20]),
(21, '8d', [15]),
)
def test_demonstrating_resample_and_rolling_are_aligned_many_combinations(
self,
n_times,
period,
nan_locations,
):
# Less readable than test_demonstrating_resample_and_rolling_are_aligned,
# but these sorts of automated checks ensure we didn't miss an edge case
# (there are many!!!!)
times = pd.date_range('2010', periods=n_times)
temperatures = np.random.RandomState(802701).rand(n_times)

for i in nan_locations or []:
temperatures[i] = np.nan

input_ds = xr.Dataset(
{
'temperature': xr.DataArray(
temperatures, coords=[times], dims=['time']
)
}
)

input_path = self.create_tempdir('source').full_path
input_ds.to_zarr(input_path)

# Get resampled output
resample_output_path = self.create_tempdir('resample').full_path
with flagsaver.as_parsed(
input_path=input_path,
output_path=resample_output_path,
method='resample',
period=period,
mean_vars='ALL',
runner='DirectRunner',
):
resample_in_time.main([])
resample, unused_output_chunks = xarray_beam.open_zarr(resample_output_path)

# Get rolled output
rolling_output_path = self.create_tempdir('rolling').full_path
with flagsaver.as_parsed(
input_path=input_path,
output_path=rolling_output_path,
method='rolling',
period=period,
mean_vars='ALL',
runner='DirectRunner',
):
resample_in_time.main([])
rolling, unused_output_chunks = xarray_beam.open_zarr(rolling_output_path)

common_times = pd.to_datetime(resample.time.data).intersection(
rolling.time.data
)

# At most, one time is lost if the period doesn't evenly divide n_times.
self.assertGreaterEqual(len(common_times), len(resample.time) - 1)
xr.testing.assert_equal(
resample.sel(time=common_times),
rolling.sel(time=common_times),
)

@parameterized.named_parameters(
Expand Down Expand Up @@ -190,6 +292,12 @@ def test_resample_time(self, method, add_mean_suffix, period):
// pd.to_timedelta(input_time_resolution)
).mean()
)
# Enact the time offsetting needed to align resample and rolling.
expected_mean = expected_mean.assign_coords(
time=expected_mean.time
- pd.to_timedelta(period)
+ pd.to_timedelta(input_time_resolution)
)
else:
raise ValueError(f'Unhandled {method=}')

Expand Down Expand Up @@ -337,6 +445,12 @@ def test_resample_prediction_timedelta(self, method, add_mean_suffix, period):
// pd.to_timedelta(input_time_resolution)
).mean()
)
# Enact the time offsetting needed to align resample and rolling.
expected_mean = expected_mean.assign_coords(
prediction_timedelta=expected_mean.prediction_timedelta
- pd.to_timedelta(period)
+ pd.to_timedelta(input_time_resolution)
)
else:
raise ValueError(f'Unhandled {method=}')

Expand Down

0 comments on commit 1f0a1d6

Please sign in to comment.