Skip to content

Commit

Permalink
Fix cartesian product.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Nov 24, 2023
1 parent 29bb62c commit 3d7d58f
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 14 deletions.
3 changes: 0 additions & 3 deletions skfda/_utils/ndfunction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,6 @@ def cartesian_product( # noqa: WPS234
)


_cartesian_product = cartesian_product


def grid_points_equal(gp1: GridPoints[A], gp2: GridPoints[A], /) -> bool:
"""Check if grid points are equal."""
shape_equal = gp1.shape == gp2.shape
Expand Down
4 changes: 2 additions & 2 deletions skfda/datasets/_samples_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scipy.stats import multivariate_normal

from .._utils import _to_grid_points, normalize_warping
from .._utils.ndfunction.utils import _cartesian_product
from .._utils.ndfunction.utils import cartesian_product
from ..misc.covariances import Brownian, CovarianceLike, _execute_covariance
from ..misc.validation import validate_random_state
from ..representation import FDataGrid
Expand Down Expand Up @@ -66,7 +66,7 @@ def make_gaussian(

grid_points = _to_grid_points(grid_points)

input_points = _cartesian_product(grid_points)
input_points = cartesian_product(grid_points)

covariance = _execute_covariance(
cov,
Expand Down
6 changes: 3 additions & 3 deletions skfda/ml/regression/_historical_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..._utils import _pairwise_symmetric
from ..._utils._sklearn_adapter import BaseEstimator, RegressorMixin
from ..._utils.ndfunction.utils import _cartesian_product
from ..._utils.ndfunction.utils import cartesian_product
from ...representation import FData, FDataBasis, FDataGrid
from ...representation.basis import (
Basis,
Expand Down Expand Up @@ -121,7 +121,7 @@ def _get_valid_points(
) -> NDArrayFloat:
"""Return the valid points as integer tuples."""
interval_points = np.arange(n_intervals + 1)
full_grid_points = _cartesian_product((interval_points, interval_points))
full_grid_points = cartesian_product((interval_points, interval_points))

past_points = full_grid_points[
full_grid_points[:, 0] <= full_grid_points[:, 1]
Expand Down Expand Up @@ -154,7 +154,7 @@ def _get_triangles(

interval_without_end = np.arange(n_intervals)

pts_coords = _cartesian_product(
pts_coords = cartesian_product(
(interval_without_end, interval_without_end),
)

Expand Down
6 changes: 3 additions & 3 deletions skfda/preprocessing/smoothing/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing_extensions import Final

from ..._utils import _to_grid_points
from ..._utils.ndfunction.utils import _cartesian_product
from ..._utils.ndfunction.utils import cartesian_product
from ...misc.lstsq import LstsqMethod, solve_regularized_weighted_lstsq
from ...misc.regularization import L2Regularization
from ...representation import FData, FDataBasis, FDataGrid
Expand Down Expand Up @@ -234,7 +234,7 @@ def _coef_matrix(
from ...misc.regularization import compute_penalty_matrix

basis_values_input = self.basis(
_cartesian_product(_to_grid_points(input_points)),
cartesian_product(_to_grid_points(input_points)),
).reshape((self.basis.n_basis, -1)).T

penalty_matrix = compute_penalty_matrix(
Expand Down Expand Up @@ -262,7 +262,7 @@ def _hat_matrix(
output_points: GridPointsLike,
) -> NDArrayFloat:
basis_values_output = self.basis(
_cartesian_product(
cartesian_product(
_to_grid_points(output_points),
),
).reshape((self.basis.n_basis, -1)).T
Expand Down
6 changes: 3 additions & 3 deletions skfda/preprocessing/smoothing/_kernel_smoothers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from ..._utils import _to_grid_points
from ..._utils.ndfunction.utils import _cartesian_product
from ..._utils.ndfunction.utils import cartesian_product
from ...misc.hat_matrix import HatMatrix, NadarayaWatsonHatMatrix
from ...misc.metrics import PairwiseMetric, l2_distance
from ...typing._base import GridPointsLike
Expand Down Expand Up @@ -131,8 +131,8 @@ def _hat_matrix(
output_points: GridPointsLike,
) -> NDArrayFloat:

input_points = _cartesian_product(_to_grid_points(input_points))
output_points = _cartesian_product(_to_grid_points(output_points))
input_points = cartesian_product(_to_grid_points(input_points))
output_points = cartesian_product(_to_grid_points(output_points))

if self.kernel_estimator is None:
self.kernel_estimator = NadarayaWatsonHatMatrix()
Expand Down

0 comments on commit 3d7d58f

Please sign in to comment.