Skip to content

Commit

Permalink
Add type hints for imputation
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Sep 7, 2024
1 parent a52bdf8 commit 23f0ca7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
20 changes: 14 additions & 6 deletions amlb/datautils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def impute_array(
return imputed


def impute_dataframe(X_fit: pd.DataFrame, *X_s: pd.DataFrame, missing_values=np.NaN, strategy='mean'):
def impute_dataframe(X_fit: pd.DataFrame, *X_s: Iterable[pd.DataFrame], missing_values: Any=np.NaN,
strategy: Literal['mean','median','mode'] | Tuple[Literal['constant'], Any] ='mean'
) -> pd.DataFrame | list[pd.DataFrame]:
"""
:param X_fit: used to fit the imputer. This dataframe is also imputed.
:param X_s: the additional (optional) dataframe that are imputed using the same imputer.
Expand All @@ -349,29 +351,35 @@ def impute_dataframe(X_fit: pd.DataFrame, *X_s: pd.DataFrame, missing_values=np.
return imputed if X_s else imputed[0]


def _impute_pd(X_fit, *X_s, missing_values=np.NaN, strategy=None, is_int=False):
def _impute_pd(
X_fit: pd.DataFrame,
*X_s: Iterable[pd.DataFrame],
missing_values: Any = np.NaN,
strategy: Literal['mean','median','mode'] | Tuple[Literal['constant'], Any] | None =None,
is_int: bool = False
) -> list[pd.DataFrame]:
if strategy == 'mean':
fill = X_fit.mean()
elif strategy == 'median':
fill = X_fit.median()
elif strategy == 'mode':
fill = X_fit.mode().iloc[0, :]
fill = X_fit.mode().iloc[0, :] # type: ignore[call-overload]
elif isinstance(strategy, tuple) and strategy[0] == 'constant':
fill = strategy[1]
else:
return [X_fit, *X_s]
return [X_fit, *X_s] # type: ignore[list-item] # doesn't seem to understand unpacking

if is_int and isinstance(fill, pd.Series):
fill = fill.round()
return [df.replace(missing_values, fill) for df in [X_fit, *X_s]]


def _rows_with_nas(X):
def _rows_with_nas(X: np.ndarray | pd.DataFrame) -> pd.DataFrame:
df = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X)
return df[df.isna().any(axis=1)]


def _restore_dtypes(X_np, X_ori):
def _restore_dtypes(X_np: np.ndarray, X_ori: pd.DataFrame | pd.Series | np.ndarray) -> pd.DataFrame | pd.Series | np.ndarray:
if isinstance(X_ori, pd.DataFrame):
df = pd.DataFrame(X_np, columns=X_ori.columns, index=X_ori.index).convert_dtypes()
df.astype(X_ori.dtypes.to_dict(), copy=False, errors='raise')
Expand Down
2 changes: 1 addition & 1 deletion amlb/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def save_predictions(dataset: Dataset, output_file: str,

if probabilities is not None:
prob_cols = probabilities_labels if probabilities_labels else dataset.target.label_encoder.classes
df = to_data_frame(probabilities, columns=prob_cols)
df = to_data_frame(probabilities, column_names=prob_cols)
if probabilities_labels is not None:
df = df[sort(prob_cols)] # reorder columns alphabetically: necessary to match label encoding
if any(prob_cols != df.columns.values):
Expand Down

0 comments on commit 23f0ca7

Please sign in to comment.