diff --git a/amlb/datautils.py b/amlb/datautils.py index 94aab0981..1d7dc42fb 100644 --- a/amlb/datautils.py +++ b/amlb/datautils.py @@ -325,7 +325,9 @@ def impute_array( return imputed -def impute_dataframe(X_fit: pd.DataFrame, *X_s: pd.DataFrame, missing_values=np.NaN, strategy='mean'): +def impute_dataframe(X_fit: pd.DataFrame, *X_s: Iterable[pd.DataFrame], missing_values: Any=np.NaN, + strategy: Literal['mean','median','mode'] | Tuple[Literal['constant'], Any] ='mean' + ) -> pd.DataFrame | list[pd.DataFrame]: """ :param X_fit: used to fit the imputer. This dataframe is also imputed. :param X_s: the additional (optional) dataframe that are imputed using the same imputer. @@ -349,29 +351,35 @@ def impute_dataframe(X_fit: pd.DataFrame, *X_s: pd.DataFrame, missing_values=np. return imputed if X_s else imputed[0] -def _impute_pd(X_fit, *X_s, missing_values=np.NaN, strategy=None, is_int=False): +def _impute_pd( + X_fit: pd.DataFrame, + *X_s: Iterable[pd.DataFrame], + missing_values: Any = np.NaN, + strategy: Literal['mean','median','mode'] | Tuple[Literal['constant'], Any] | None =None, + is_int: bool = False +) -> list[pd.DataFrame]: if strategy == 'mean': fill = X_fit.mean() elif strategy == 'median': fill = X_fit.median() elif strategy == 'mode': - fill = X_fit.mode().iloc[0, :] + fill = X_fit.mode().iloc[0, :] # type: ignore[call-overload] elif isinstance(strategy, tuple) and strategy[0] == 'constant': fill = strategy[1] else: - return [X_fit, *X_s] + return [X_fit, *X_s] # type: ignore[list-item] # doesn't seem to understand unpacking if is_int and isinstance(fill, pd.Series): fill = fill.round() return [df.replace(missing_values, fill) for df in [X_fit, *X_s]] -def _rows_with_nas(X): +def _rows_with_nas(X: np.ndarray | pd.DataFrame) -> pd.DataFrame: df = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X) return df[df.isna().any(axis=1)] -def _restore_dtypes(X_np, X_ori): +def _restore_dtypes(X_np: np.ndarray, X_ori: pd.DataFrame | pd.Series | np.ndarray) -> pd.DataFrame | pd.Series | np.ndarray: if isinstance(X_ori, pd.DataFrame): df = pd.DataFrame(X_np, columns=X_ori.columns, index=X_ori.index).convert_dtypes() df.astype(X_ori.dtypes.to_dict(), copy=False, errors='raise') diff --git a/amlb/results.py b/amlb/results.py index 4cef9498e..b3991fac5 100644 --- a/amlb/results.py +++ b/amlb/results.py @@ -309,7 +309,7 @@ def save_predictions(dataset: Dataset, output_file: str, if probabilities is not None: prob_cols = probabilities_labels if probabilities_labels else dataset.target.label_encoder.classes - df = to_data_frame(probabilities, columns=prob_cols) + df = to_data_frame(probabilities, column_names=prob_cols) if probabilities_labels is not None: df = df[sort(prob_cols)] # reorder columns alphabetically: necessary to match label encoding if any(prob_cols != df.columns.values):