diff --git a/skfda/representation/_functional_data.py b/skfda/representation/_functional_data.py index ee8813bbc..dda4a80c2 100644 --- a/skfda/representation/_functional_data.py +++ b/skfda/representation/_functional_data.py @@ -882,6 +882,7 @@ def mean( out: None = None, keepdims: bool = False, skipna: bool = False, + min_count: int = 0, ) -> T: """Compute the mean of all the samples. @@ -891,6 +892,9 @@ def mean( out: Used for compatibility with numpy. Must be None. keepdims: Used for compatibility with numpy. Must be False. skipna: Wether the NaNs are ignored or not. + min_count: Number of valid (non NaN) data to have in order + for the a variable to not be NaN when `skipna` is + `True`. Returns: A FData object with just one sample representing @@ -902,10 +906,7 @@ def mean( "Not implemented for that parameter combination", ) - return ( - self.sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna) - / self.n_samples - ) + return self @abstractmethod def to_grid( diff --git a/skfda/representation/basis/_fdatabasis.py b/skfda/representation/basis/_fdatabasis.py index b528e3098..77b4a9150 100644 --- a/skfda/representation/basis/_fdatabasis.py +++ b/skfda/representation/basis/_fdatabasis.py @@ -427,20 +427,53 @@ def sum( # noqa: WPS125 """ super().sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna) - coefs = ( - np.nansum(self.coefficients, axis=0) if skipna - else np.sum(self.coefficients, axis=0) - ) - - if min_count > 0: - valid = ~np.isnan(self.coefficients) - n_valid = np.sum(valid, axis=0) - coefs[n_valid < min_count] = np.nan + valid_functions = ~self.isna() + valid_coefficients = self.coefficients[valid_functions] + + coefs = np.sum(valid_coefficients, axis=0) return self.copy( coefficients=coefs, sample_names=(None,), ) + + def mean( # noqa: WPS125 + self: T, + *, + axis: Optional[int] = None, + dtype: None = None, + out: None = None, + keepdims: bool = False, + skipna: bool = False, + min_count: int = 0, + ) -> T: + """Compute the mean of all the samples. + + Args: + axis: Used for compatibility with numpy. Must be None or 0. + dtype: Used for compatibility with numpy. Must be None. + out: Used for compatibility with numpy. Must be None. + keepdims: Used for compatibility with numpy. Must be False. + skipna: Wether the NaNs are ignored or not. + min_count: Ignored, used for compatibility with FDataGrid + and FDataIrregular. + + Returns: + A FDataBasis object with just one sample representing + the mean of all the samples in the original object. + """ + super().mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims, + skipna=skipna) + + return ( + self.sum( + axis=axis, + out=out, + keepdims=keepdims, + skipna=skipna, + ) + / np.sum(~self.isna()), + ) def var( self: T, @@ -998,7 +1031,7 @@ def isna(self) -> NDArrayBool: Returns: na_values (np.ndarray): Positions of NA. """ - return np.all( # type: ignore[no-any-return] + return np.any( # type: ignore[no-any-return] np.isnan(self.coefficients), axis=1, ) diff --git a/skfda/representation/grid.py b/skfda/representation/grid.py index 50bb96169..315d2a6d8 100644 --- a/skfda/representation/grid.py +++ b/skfda/representation/grid.py @@ -544,6 +544,60 @@ def _get_points_and_values(self: T) -> Tuple[NDArrayFloat, NDArrayFloat]: def _get_input_points(self: T) -> GridPoints: return self.grid_points + + def _compute_aggregate( + self: T, + operation: str, + *, + skipna: bool = False, + min_count: int = 0, + ) -> T: + """Compute a defined aggregation operation of the samples. + + Args: + operation: Operation to be performed. Can be 'mean', 'sum' or + 'var'. + axis: Used for compatibility with numpy. Must be None or 0. + out: Used for compatibility with numpy. Must be None. + keepdims: Used for compatibility with numpy. Must be False. + skipna: Wether the NaNs are ignored or not. + min_count: Number of valid (non NaN) data to have in order + for the a variable to not be NaN when `skipna` is + `True`. + + Returns: + An FDataGrid object with just one sample representing + the aggregation of all the samples in the original object. + + """ + if operation not in {'sum', 'mean', 'var'}: + raise ValueError("Invalid operation." + "Must be one of 'sum', 'mean', or 'var'.") + + if skipna: + agg_func = { + 'sum': np.nansum, + 'mean': np.nanmean, + 'var': np.nanvar + }[operation] + else: + agg_func = { + 'sum': np.sum, + 'mean': np.mean, + 'var': np.var + }[operation] + + data = agg_func(self.data_matrix, axis=0, keepdims=True) + + if min_count > 0 and skipna: + valid = ~np.isnan(self.data_matrix) + n_valid = np.sum(valid, axis=0) + data[n_valid < min_count] = np.nan + + return self.copy( + data_matrix=data, + sample_names=(None,), + ) def sum( # noqa: WPS125 self: T, @@ -583,19 +637,50 @@ def sum( # noqa: WPS125 """ super().sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna) - data = ( - np.nansum(self.data_matrix, axis=0, keepdims=True) if skipna - else np.sum(self.data_matrix, axis=0, keepdims=True) + return self._compute_aggregate( + operation='sum', + skipna=skipna, + min_count=min_count, ) - if min_count > 0: - valid = ~np.isnan(self.data_matrix) - n_valid = np.sum(valid, axis=0) - data[n_valid < min_count] = np.nan + def mean( # noqa: WPS125 + self: T, + *, + axis: Optional[int] = None, + dtype: None = None, + out: None = None, + keepdims: bool = False, + skipna: bool = False, + min_count: int = 0, + ) -> T: + """Compute the mean of all the samples. - return self.copy( - data_matrix=data, - sample_names=(None,), + Args: + axis: Used for compatibility with numpy. Must be None or 0. + dtype: Used for compatibility with numpy. Must be None. + out: Used for compatibility with numpy. Must be None. + keepdims: Used for compatibility with numpy. Must be False. + skipna: Wether the NaNs are ignored or not. + min_count: Number of valid (non NaN) data to have in order + for the a variable to not be NaN when `skipna` is + `True`. + + Returns: + A FDataGrid object with just one sample representing + the mean of all the samples in the original object. + """ + super().mean( + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + skipna=skipna, + ) + + return self._compute_aggregate( + operation='mean', + skipna=skipna, + min_count=min_count, ) def var(self: T, correction: int = 0) -> T: diff --git a/skfda/representation/irregular.py b/skfda/representation/irregular.py index cf19c8cae..dfa916230 100644 --- a/skfda/representation/irregular.py +++ b/skfda/representation/irregular.py @@ -716,6 +716,65 @@ def sum( # noqa: WPS125 values=sum_values, sample_names=(None,), ) + + def mean( # noqa: WPS125 + self: T, + *, + axis: Optional[int] = None, + dtype: None = None, + out: None = None, + keepdims: bool = False, + skipna: bool = False, + min_count: int = 0, + ) -> T: + """Compute the mean of all the samples. + + Args: + axis: Used for compatibility with numpy. Must be None or 0. + dtype: Used for compatibility with numpy. Must be None. + out: Used for compatibility with numpy. Must be None. + keepdims: Used for compatibility with numpy. Must be False. + skipna: Wether the NaNs are ignored or not. + min_count: Number of valid (non NaN) data to have in order + for the a variable to not be NaN when `skipna` is + `True`. + + Returns: + An FDataIrregular object with just one sample representing + the mean of all the samples in the original object. + """ + super().mean( + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + skipna=skipna, + ) + + common_points, common_values = self._get_common_points_and_values() + + if len(common_points) == 0: + raise ValueError("No common points in FDataIrregular object") + + sum_function = np.nansum if skipna else np.sum + sum_values = sum_function(common_values, axis=0) + + if skipna: + count_values = np.sum(~np.isnan(common_values), axis=0) + else: + count_values = np.full(sum_values.shape, self.n_samples) + + if min_count > 0 and skipna: + count_values[count_values < min_count] = np.nan + + mean_values = sum_values / count_values + + return FDataIrregular( + start_indices=np.array([0]), + points=common_points, + values=mean_values, + sample_names=(None,), + ) def var(self: T, correction: int = 0) -> T: """Compute the variance of all the samples.