diff --git a/skfda/misc/covariances.py b/skfda/misc/covariances.py index 298fcf64d..00baa7e66 100644 --- a/skfda/misc/covariances.py +++ b/skfda/misc/covariances.py @@ -1,14 +1,18 @@ +"""Covariances module.""" from __future__ import annotations import abc -from typing import Callable, Sequence, Tuple, Union +from typing import Any, Callable, Sequence import matplotlib.pyplot as plt import numpy as np import sklearn.gaussian_process.kernels as sklearn_kern from matplotlib.figure import Figure +from numpy.typing import NDArray from scipy.special import gamma, kv +from ..misc._math import inner_product_matrix +from ..misc.metrics import PairwiseMetric, l2_distance from ..representation import FData, FDataBasis, FDataGrid from ..representation.basis import TensorBasis from ..typing._numpy import ArrayLike, NDArrayFloat @@ -20,11 +24,13 @@ def _squared_norms(x: NDArrayFloat, y: NDArrayFloat) -> NDArrayFloat: ).sum(2) -CovarianceLike = Union[ - float, - NDArrayFloat, - Callable[[ArrayLike, ArrayLike], NDArrayFloat], -] +CovarianceLike = ( + float + | NDArrayFloat + | Callable[[ArrayLike, ArrayLike], NDArrayFloat] +) + +Input = NDArray[Any] | FData def _transform_to_2d(t: ArrayLike) -> NDArrayFloat: @@ -51,31 +57,59 @@ def _execute_covariance( if isinstance(covariance, (int, float)): return np.array(covariance) + + if callable(covariance): + result = covariance(x, y) + elif isinstance(covariance, np.ndarray): + result = covariance else: - if callable(covariance): - result = covariance(x, y) - elif isinstance(covariance, np.ndarray): - result = covariance - else: - # GPy kernel - result = covariance.K(x, y) + # GPy kernel + result = covariance.K(x, y) - assert result.shape[0] == len(x) - assert result.shape[1] == len(y) - return result + assert result.shape[0] == len(x) + assert result.shape[1] == len(y) + return result class Covariance(abc.ABC): """Abstract class for covariance functions.""" - _parameters_str: Sequence[Tuple[str, str]] + _parameters_str: Sequence[tuple[str, str]] _latex_formula: str @abc.abstractmethod - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute covariance function on input data.""" pass - def heatmap(self, limits: Tuple[float, float] = (-1, 1)) -> Figure: + def _param_check_and_transform( + self, + x: Input, + y: Input | None = None, + ) -> tuple[Input, Input]: + # Param check + if y is None: + y = x + + if type(x) is not type(y): # noqa: WPS516 + raise ValueError( + 'Cannot operate objects x and y from different classes', + f'({type(x)}, {type(y)}).', + ) + + if not isinstance(x, FData) and not isinstance(y, FData): + if len(x.shape) < 2: + x = np.atleast_2d(x) + if len(y.shape) < 2: + y = np.atleast_2d(y) + + return x, y + + def heatmap(self, limits: tuple[float, float] = (-1, 1)) -> Figure: """Return a heatmap plot of the covariance function.""" from ..exploratory.visualization._utils import _create_figure @@ -147,7 +181,10 @@ def _repr_html_(self) -> str: row_style = '' - def column_style(percent: float, margin_top: str = "0") -> str: + def column_style( # noqa: WPS430 + percent: float, + margin_top: str = "0", + ) -> str: return ( f'style="display: inline-block; ' f'margin:0; ' @@ -171,7 +208,7 @@ def column_style(percent: float, margin_top: str = "0") -> str: {heatmap} - """ + """ # noqa: WPS432, WPS318 def to_sklearn(self) -> sklearn_kern.Kernel: """Convert it to a sklearn kernel, if there is one.""" @@ -231,6 +268,7 @@ class Brownian(Covariance): Brownian() """ + _latex_formula = ( r"K(x, x') = \sigma^2 \frac{|x - \mathcal{O}| + " r"|x' - \mathcal{O}| - |x - x'|}{2}" @@ -245,9 +283,22 @@ def __init__(self, *, variance: float = 1, origin: float = 0) -> None: self.variance = variance self.origin = origin - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: - x = _transform_to_2d(x) - self.origin - y = _transform_to_2d(y) - self.origin + def __call__( + self, + x: NDArray[Any], + y: NDArray[Any] | None = None, + ) -> NDArrayFloat: + """Compute Brownian covariance function on input data.""" + if isinstance(x, FData) or isinstance(y, FData): + raise ValueError( + 'Brownian covariance not defined for FData objects.', + ) + + x = _transform_to_2d(x) + y = _transform_to_2d(y) + + x = x - self.origin + y = y - self.origin sum_norms = np.add.outer( np.linalg.norm(x, axis=-1), @@ -319,13 +370,19 @@ def __init__(self, *, variance: float = 1, intercept: float = 0) -> None: self.variance = variance self.intercept = intercept - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: - x = _transform_to_2d(x) - y = _transform_to_2d(y) + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute linear covariance function on input data.""" + x, y = self._param_check_and_transform(x, y) - return self.variance * (x @ y.T + self.intercept) + x_y = inner_product_matrix(x, y) + return self.variance * (x_y + self.intercept) def to_sklearn(self) -> sklearn_kern.Kernel: + """Obtain corresponding scikit-learn kernel type.""" return ( self.variance * (sklearn_kern.DotProduct(0) + self.intercept) @@ -400,16 +457,22 @@ def __init__( self.slope = slope self.degree = degree - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: - x = _transform_to_2d(x) - y = _transform_to_2d(y) + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute polynomial covariance function on input data.""" + x, y = self._param_check_and_transform(x, y) + x_y = inner_product_matrix(x, y) return ( self.variance - * (self.slope * x @ y.T + self.intercept) ** self.degree + * (self.slope * x_y + self.intercept) ** self.degree ) def to_sklearn(self) -> sklearn_kern.Kernel: + """Obtain corresponding scikit-learn kernel type.""" return ( self.variance * (self.slope * sklearn_kern.DotProduct(0) + self.intercept) @@ -475,15 +538,21 @@ def __init__(self, *, variance: float = 1, length_scale: float = 1): self.variance = variance self.length_scale = length_scale - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: - x = _transform_to_2d(x) - y = _transform_to_2d(y) - - x_y = _squared_norms(x, y) - - return self.variance * np.exp(-x_y / (2 * self.length_scale ** 2)) + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute Gaussian covariance function on input data.""" + x, y = self._param_check_and_transform(x, y) + + distance_x_y = PairwiseMetric(l2_distance)(x, y) + return self.variance * np.exp( # type: ignore[no-any-return] + -distance_x_y ** 2 / (2 * self.length_scale ** 2), + ) def to_sklearn(self) -> sklearn_kern.Kernel: + """Obtain corresponding scikit-learn kernel type.""" return ( self.variance * sklearn_kern.RBF(length_scale=self.length_scale) ) @@ -552,14 +621,22 @@ def __init__( self.variance = variance self.length_scale = length_scale - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: - x = _transform_to_2d(x) - y = _transform_to_2d(y) - - x_y = _squared_norms(x, y) - return self.variance * np.exp(-np.sqrt(x_y) / (self.length_scale)) + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute exponential covariance function on input data.""" + x, y = self._param_check_and_transform(x, y) + + distance_x_y = PairwiseMetric(l2_distance)(x, y) + return self.variance * np.exp( # type: ignore[no-any-return] + -distance_x_y + / (self.length_scale), + ) def to_sklearn(self) -> sklearn_kern.Kernel: + """Obtain corresponding scikit-learn kernel type.""" return ( self.variance * sklearn_kern.Matern(length_scale=self.length_scale, nu=0.5) @@ -623,11 +700,20 @@ class WhiteNoise(Covariance): def __init__(self, *, variance: float = 1): self.variance = variance - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute white noise covariance function on input data.""" + if isinstance(x, FData) or isinstance(y, FData): + raise ValueError('Not defined for FData objects.') + x = _transform_to_2d(x) return self.variance * np.eye(x.shape[0]) def to_sklearn(self) -> sklearn_kern.Kernel: + """Obtain corresponding scikit-learn kernel type.""" return sklearn_kern.WhiteKernel(noise_level=self.variance) @@ -644,7 +730,7 @@ class Matern(Covariance): where :math:`\sigma^2` is the variance, :math:`l` is the length scale and :math:`\nu` controls the smoothness of the related Gaussian process. - The trajectories of a Gaussian process with Matérn covariance is + The trajectories of a Gaussian process with Matérn covariance is :math:`\lceil \nu \rceil - 1` times differentiable. @@ -680,6 +766,7 @@ class Matern(Covariance): Matern() """ + _latex_formula = ( r"K(x, x') = \sigma^2 \frac{2^{1-\nu}}{\Gamma(\nu)}" r"\left( \frac{\sqrt{2\nu}|x - x'|}{l} \right)^{\nu}" @@ -703,18 +790,21 @@ def __init__( self.length_scale = length_scale self.nu = nu - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: - x = _transform_to_2d(x) - y = _transform_to_2d(y) + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: + """Compute Matern covariance function on input data.""" + x, y = self._param_check_and_transform(x, y) - x_y_squared = _squared_norms(x, y) - x_y = np.sqrt(x_y_squared) + distance_x_y = PairwiseMetric(l2_distance)(x, y) p = self.nu - 0.5 if p.is_integer(): # Formula for half-integers p = int(p) - body = np.sqrt(2 * p + 1) * x_y / self.length_scale + body = np.sqrt(2 * p + 1) * distance_x_y / self.length_scale exponential = np.exp(-body) power_list = np.full(shape=(p,) + body.shape, fill_value=2 * body) power_list = np.cumprod(power_list, axis=0) @@ -734,28 +824,29 @@ def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: self.variance * exponential * np.sum(sum_terms, axis=-1) ) elif self.nu == np.inf: - return ( + return ( # type: ignore[no-any-return] self.variance * np.exp( - -x_y_squared / (2 * self.length_scale ** 2), + -distance_x_y ** 2 / (2 * self.length_scale ** 2), ) ) - else: - # General formula - scaling = 2**(1 - self.nu) / gamma(self.nu) - body = np.sqrt(2 * self.nu) * x_y / self.length_scale - power = body**self.nu - bessel = kv(self.nu, body) - - with np.errstate(invalid='ignore'): - eval_cov = self.variance * scaling * power * bessel - - # Values with nan are where the distance is 0 - return np.nan_to_num( # type: ignore[no-any-return] - eval_cov, - nan=self.variance, - ) + + # General formula + scaling = 2**(1 - self.nu) / gamma(self.nu) + body = np.sqrt(2 * self.nu) * distance_x_y / self.length_scale + power = body**self.nu + bessel = kv(self.nu, body) + + with np.errstate(invalid='ignore'): + eval_cov = self.variance * scaling * power * bessel + + # Values with nan are where the distance is 0 + return np.nan_to_num( # type: ignore[no-any-return] + eval_cov, + nan=self.variance, + ) def to_sklearn(self) -> sklearn_kern.Kernel: + """Obtain corresponding scikit-learn kernel type.""" return ( self.variance * sklearn_kern.Matern(length_scale=self.length_scale, nu=self.nu) @@ -798,7 +889,11 @@ def __init__(self, data: FData) -> None: "for univariate functions", ) - def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: + def __call__( + self, + x: Input, + y: Input | None = None, + ) -> NDArrayFloat: """Evaluate the covariance function. Args: @@ -808,6 +903,9 @@ def __call__(self, x: ArrayLike, y: ArrayLike) -> NDArrayFloat: Returns: Covariance function evaluated at the grid formed by x and y. """ + if isinstance(x, FData) or isinstance(y, FData): + raise ValueError('Not defined for FData objects.') + return self.cov_fdata([x, y], grid=True)[0, ..., 0] diff --git a/skfda/tests/test_covariances.py b/skfda/tests/test_covariances.py index 1005da49d..1e8550564 100644 --- a/skfda/tests/test_covariances.py +++ b/skfda/tests/test_covariances.py @@ -1,106 +1,241 @@ -import unittest +"""Tests for Covariance module.""" +from typing import Any import numpy as np - -import skfda.misc.covariances - - -class TestsSklearn(unittest.TestCase): - - def setUp(self) -> None: - unittest.TestCase.setUp(self) - - self.x = np.linspace(-1, 1, 1000)[:, np.newaxis] - - def _test_compare_sklearn( - self, - cov: skfda.misc.covariances.Covariance, - ) -> None: - cov_sklearn = cov.to_sklearn() - cov_matrix = cov(self.x, self.x) - cov_sklearn_matrix = cov_sklearn(self.x) - - np.testing.assert_array_almost_equal(cov_matrix, cov_sklearn_matrix) - - def test_linear(self) -> None: - - for variance in (1, 2): - for intercept in (0, 1, 2): - with self.subTest(variance=variance, intercept=intercept): - cov = skfda.misc.covariances.Linear( - variance=variance, intercept=intercept) - self._test_compare_sklearn(cov) - - def test_polynomial(self) -> None: - - # Test a couple of non-default parameters only for speed - for variance in (2,): - for intercept in (0, 2): - for slope in (1, 2): - for degree in (1, 2, 3): - with self.subTest( - variance=variance, - intercept=intercept, - slope=slope, - degree=degree, - ): - cov = skfda.misc.covariances.Polynomial( - variance=variance, - intercept=intercept, - slope=slope, - degree=degree, - ) - self._test_compare_sklearn(cov) - - def test_gaussian(self) -> None: - - for variance in (1, 2): - for length_scale in (0.5, 1, 2): - with self.subTest( - variance=variance, - length_scale=length_scale, - ): - cov = skfda.misc.covariances.Gaussian( - variance=variance, - length_scale=length_scale, - ) - self._test_compare_sklearn(cov) - - def test_exponential(self) -> None: - - for variance in (1, 2): - for length_scale in (0.5, 1, 2): - with self.subTest( - variance=variance, - length_scale=length_scale, - ): - cov = skfda.misc.covariances.Exponential( - variance=variance, - length_scale=length_scale, - ) - self._test_compare_sklearn(cov) - - def test_matern(self) -> None: - - # Test a couple of non-default parameters only for speed - for variance in (2,): - for length_scale in (0.5,): - for nu in (0.5, 1, 1.5, 2.5, 3.5, np.inf): - with self.subTest( - variance=variance, - length_scale=length_scale, - nu=nu, - ): - cov = skfda.misc.covariances.Matern( - variance=variance, - length_scale=length_scale, - nu=nu, - ) - self._test_compare_sklearn(cov) - - def test_white_noise(self) -> None: - - for variance in (1, 2): - with self.subTest(variance=variance): - cov = skfda.misc.covariances.WhiteNoise(variance=variance) - self._test_compare_sklearn(cov) +import pytest +from sklearn.model_selection import ParameterGrid + +import skfda.misc.covariances as cov +from skfda import FDataBasis, FDataGrid +from skfda.datasets import fetch_weather +from skfda.representation.basis import MonomialBasis + + +def _test_compare_sklearn( + multivariate_data: Any, + cov: cov.Covariance, +) -> None: + cov_sklearn = cov.to_sklearn() + cov_matrix = cov(multivariate_data) + cov_sklearn_matrix = cov_sklearn(multivariate_data) + + np.testing.assert_array_almost_equal(cov_matrix, cov_sklearn_matrix) + +############################################################################### +# Example datasets for which to calculate the evaluation of the kernel by hand +# to compare it against the results yielded by the implementation. +############################################################################### + + +basis = MonomialBasis(n_basis=2, domain_range=(-2, 2)) + +fd = [ + FDataBasis(basis=basis, coefficients=[[1, 0], [1, 2]]), + FDataBasis(basis=basis, coefficients=[[0, 1]]), +] + +############################################################################## +# Fixtures +############################################################################## + + +@pytest.fixture +def fetch_weather_subset() -> FDataGrid: + """Fixture for loading the canadian weather dataset example.""" + fd, _ = fetch_weather(return_X_y=True) + return fd[:20] + + +@pytest.fixture( + params=[ + cov.Linear(), + cov.Polynomial(), + cov.Gaussian(), + cov.Exponential(), + cov.Matern(), + ], +) +def covariances_fixture(request: Any) -> Any: + """Fixture for getting a covariance kernel function.""" + return request.param + + +@pytest.fixture( + params=[ + cov.Brownian(), + cov.WhiteNoise(), + ], +) +def covariances_raise_fixture(request: Any) -> Any: + """Fixture for getting a covariance kernel that raises a ValueError.""" + return request.param + + +@pytest.fixture( + params=[ + [ + cov.Linear(variance=1 / 2, intercept=3), + np.array([[3 / 2], [3 / 2 + 32 / 6]]), + ], + [ + cov.Polynomial(variance=1 / 3, slope=2, intercept=1, degree=2), + np.array([[1 / 3], [67**2 / 3**3]]), + ], + [ + cov.Gaussian(variance=3, length_scale=2), + np.array([[3 * np.exp(-7 / 6)], [3 * np.exp(-7 / 6)]]), + ], + [ + cov.Exponential(variance=4, length_scale=5), + np.array([ + [4 * np.exp(-np.sqrt(28 / 3) / 5)], + [4 * np.exp(-np.sqrt(28 / 3) / 5)], + ]), + ], + [ + cov.Matern(variance=2, length_scale=3, nu=2), + np.array([ + [(2 / 3) ** 2 * (28 / 3) * 0.239775899566], + [(2 / 3) ** 2 * (28 / 3) * 0.239775899566], + ]), + ], + ], +) +def precalc_example_data( + request: Any, +) -> list[FDataBasis, FDataBasis, cov.Covariance, np.array]: + """Fixture for getting fdatabasis objects. + + The dataset is used to test manual calculations of the covariance functions + against the implementation. + """ + # First fd, Second fd, kernel used, result + return *fd, *request.param + + +@pytest.fixture +def multivariate_data() -> np.array: + """Fixture for getting multivariate data.""" + return np.linspace(-1, 1, 1000)[:, np.newaxis] + + +@pytest.fixture( + params=[ + [ + cov.Linear, + { + "variance": [1, 2], + "intercept": [3, 4], + }, + ], + [ + cov.Polynomial, + { + "variance": [2], + "intercept": [0, 2], + "slope": [1, 2], + "degree": [1, 2, 3], + }, + ], + [ + cov.Exponential, + { + "variance": [1, 2], + "length_scale": [0.5, 1, 2], + }, + ], + [ + cov.Gaussian, + { + "variance": [1, 2], + "length_scale": [0.5, 1, 2], + }, + ], + [ + cov.Matern, + { + "variance": [2], + "length_scale": [0.5], + "nu": [0.5, 1, 1.5, 2.5, 3.5, np.inf], + }, + ], + ], +) +def covariance_and_params(request: Any) -> Any: + """Fixture to load the covariance functions.""" + return request.param + + +############################################################################## +# Tests +############################################################################## + + +def test_covariances( + fetch_weather_subset: FDataGrid, + covariances_fixture: cov.Covariance, +) -> None: + """Check that parameter conversion is done correctly.""" + fd = fetch_weather_subset + cov_kernel = covariances_fixture + + # Also test that it does not fail + res1 = cov_kernel(fd, fd) + res2 = cov_kernel(fd) + + np.testing.assert_allclose( + res1, + res2, + atol=1e-7, + ) + + +def test_raises( + fetch_weather_subset: FDataGrid, + covariances_raise_fixture: Any, +) -> None: + """Check raises ValueError. + + Check that non-functional kernels raise a ValueError exception + with functional data. + """ + fd = fetch_weather_subset + cov_kernel = covariances_raise_fixture + + pytest.raises( + ValueError, + cov_kernel, + fd, + ) + + +def test_precalc_example( + precalc_example_data: list[ # noqa: WPS320 + FDataBasis, FDataBasis, cov.Covariance, np.array, + ], +): + """Check the precalculated example for Linear covariance kernel. + + Compare the theoretical precalculated results against the covariance kernel + implementation, for different kernels. + The structure of the input is a list containing: + [First functional dataset, Second functional dataset, + Covariance kernel used, Result] + """ + fd1, fd2, kernel, precalc_result = precalc_example_data + computed_result = kernel(fd1, fd2) + np.testing.assert_allclose( + computed_result, + precalc_result, + rtol=1e-6, + ) + + +def test_multivariate_covariance_kernel( + multivariate_data: np.array, + covariance_and_params: Any, +) -> None: + """Test general covariance kernel against scikit-learn's kernel.""" + cov_kernel, param_dict = covariance_and_params + for input_params in list(ParameterGrid(param_dict)): + _test_compare_sklearn(multivariate_data, cov_kernel(**input_params))