Skip to content

Commit

Permalink
Add off-grid non-Gaussian unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Dec 10, 2023
1 parent 33b8c2b commit 90bec68
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,41 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
df_raw.reset_index()["longitude"],
)

def test_highlevel_predict_with_pred_params(self):
def test_highlevel_predict_with_pred_params_pandas(self):
"""
Test that passing ``pred_params`` to ``.predict`` works with
a spikes-beta likelihood.
a spikes-beta likelihood for prediction to pandas.
"""
tl = TaskLoader(context=self.da, target=self.da)
model = ConvNP(
self.dp,
tl,
unet_channels=(5, 5, 5),
verbose=False,
likelihood="cnp-spikes-beta",
)
task = tl("2020-01-01", context_sampling=10, target_sampling=10)

# Off-grid prediction
X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])

# Check that nothing breaks and the correct parameters are returned
pred_params = ["mean", "std", "variance", "alpha", "beta"]
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["var"]

# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
assert pred_param in pred["var"]

def test_highlevel_predict_with_pred_params_xarray(self):
"""
Test that passing ``pred_params`` to ``.predict`` works with
a spikes-beta likelihood for prediction to xarray.
"""
tl = TaskLoader(context=self.da, target=self.da)
model = ConvNP(
Expand Down

0 comments on commit 90bec68

Please sign in to comment.