Skip to content

Commit

Permalink
EM bugfixing + fix test condition (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcebere authored Feb 28, 2023
1 parent 8c7a1bc commit cbc4cc9
Show file tree
Hide file tree
Showing 14 changed files with 21 additions and 19 deletions.
12 changes: 7 additions & 5 deletions src/hyperimpute/plugins/imputers/plugin_EM.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin
from sklearn.impute import SimpleImputer

# hyperimpute absolute
import hyperimpute.logger as log
Expand Down Expand Up @@ -187,10 +188,11 @@ def _impute_em(self, X: np.ndarray) -> np.ndarray:
log.critical(f"EM step failed. {e}")
break

if np.all(np.isnan(X_reconstructed)):
err = "The imputed result contains nan. This is a bug. Please report it on the issue tracker."
log.critical(err)
raise RuntimeError(err)
if np.any(np.isnan(X_reconstructed)):
# fallback to mean imputation in case of singular matrix.
X_reconstructed = SimpleImputer(strategy="mean").fit_transform(
X_reconstructed
)

return X_reconstructed

Expand Down Expand Up @@ -231,7 +233,7 @@ def __init__(
) -> None:
super().__init__(random_state=random_state)

self._model = EM()
self._model = EM(maxit=maxit, convergence_threshold=convergence_threshold)

@decorators.benchmark
def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "EMPlugin":
Expand Down
2 changes: 1 addition & 1 deletion src/hyperimpute/plugins/imputers/plugin_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def transform(self, Xmiss: torch.Tensor) -> torch.Tensor:
imputed_data[:, i] = imputed_data[:, i] * (max_val[i] + EPS)
imputed_data[:, i] = imputed_data[:, i] + min_val[i]

if np.all(np.isnan(imputed_data.detach().cpu().numpy())):
if np.any(np.isnan(imputed_data.detach().cpu().numpy())):
err = "The imputed result contains nan. This is a bug. Please report it on the issue tracker."
log.critical(err)
raise RuntimeError(err)
Expand Down
2 changes: 1 addition & 1 deletion src/hyperimpute/plugins/imputers/plugin_softimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _svd(self, X: np.ndarray, shrink_val: float) -> np.ndarray:
s_thresh = np.diag(s_thresh)
X_reconstructed = np.dot(U_thresh, np.dot(s_thresh, V_thresh))

if np.all(np.isnan(X_reconstructed)):
if np.any(np.isnan(X_reconstructed)):
err = "The imputed result contains nan. This is a bug. Please report it on the issue tracker."
log.critical(err)
raise RuntimeError(err)
Expand Down
2 changes: 1 addition & 1 deletion src/hyperimpute/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.16"
__version__ = "0.1.17"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
MINOR_VERSION = __version__.split(".")[-1]
4 changes: 2 additions & 2 deletions tests/imputers/test_em.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_em_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.slow
Expand All @@ -66,7 +66,7 @@ def test_em_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
@pytest.mark.parametrize("p_miss", [0.5])
@pytest.mark.parametrize(
"other_plugin",
[Imputers().get("mean"), Imputers().get("median"), Imputers().get("most_frequent")],
[Imputers().get("most_frequent")],
)
def test_compare_methods_perf(
test_plugin: ImputerPlugin, mechanism: str, p_miss: float, other_plugin: Any
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_gain_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.slow
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_ice_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))

with pytest.raises(ValueError):
test_plugin.fit_transform({"invalid": "input"})
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_mice.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_mice_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))

with pytest.raises(ValueError):
test_plugin.fit_transform({"invalid": "input"})
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_missforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_missforest_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_serde()])
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_miwae.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_miwae_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.slow
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_sinkhorn_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.slow
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_sklearn_ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_ice_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))

with pytest.raises(ValueError):
test_plugin.fit_transform({"invalid": "input"})
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_sklearn_missforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_sklearn_missforest_plugin_fit_transform(test_plugin: ImputerPlugin) ->
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_serde()])
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_softimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_softimpute_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.slow
Expand Down

0 comments on commit cbc4cc9

Please sign in to comment.