Skip to content

Commit

Permalink
Add pointwise error computation to method and add to forecasting unit…
Browse files Browse the repository at this point in the history
… test
  • Loading branch information
tom-andersson committed Oct 20, 2024
1 parent dcdab7a commit fe696a9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
1 change: 1 addition & 0 deletions deepsensor/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .metrics import *
24 changes: 24 additions & 0 deletions deepsensor/eval/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import xarray as xr
from deepsensor.model.pred import Prediction


def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
"""
Compute errors between predictions and targets.
Args:
pred: Prediction object.
target: Target data.
Returns:
xr.Dataset: Dataset of pointwise differences between predictions and targets
at the same valid time in the predictions. Note, the difference is positive
when the prediction is greater than the target.
"""
errors = {}
for var_ID, pred_var in pred.items():
target_var = target[var_ID]
error = pred_var["mean"] - target_var.sel(time=pred_var.time)
error.name = f"{var_ID}"
errors[var_ID] = error
return xr.Dataset(errors)
13 changes: 10 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP
from deepsensor.train.train import Trainer
from deepsensor.eval.metrics import compute_errors

from tests.utils import gen_random_data_xr, gen_random_data_pandas

Expand Down Expand Up @@ -686,9 +687,15 @@ def test_forecasting_model_predict_return_valid_times(self):

if isinstance(pred_var, xr.Dataset):
# Check we can compute errors using the valid time coord ('time')
errors = pred_var["mean"] - self.da.sel(time=pred_var.time)
assert errors.dims == ("lead_time", "init_time", "x1", "x2")
assert errors.shape == pred_var["mean"].shape
errors = compute_errors(pred, self.da.to_dataset())
for var_ID in errors.keys():
assert tuple(errors[var_ID].dims) == (
"lead_time",
"init_time",
"x1",
"x2",
)
assert errors[var_ID].shape == pred[var_ID]["mean"].shape
elif isinstance(pred_var, pd.DataFrame):
# Makes coordinate checking easier by avoiding repeat values
pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
Expand Down

0 comments on commit fe696a9

Please sign in to comment.