✨ allow predict_quantified to return values by alias (#157)
Simply pass `return_alias_dict=True` and the return value will be a dict mapping quantifier aliases to `(prediction, confidence_or_uncertainty)` tuples.

Closes #46
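A minimal end-to-end sketch (mirroring the dummy classifier used in the tests below; the model is untrained, so the outputs are meaningless, but the types and shapes are real):

```python
import numpy as np
import tensorflow as tf

from uncertainty_wizard.models import StochasticSequential
from uncertainty_wizard.quantifiers import MaxSoftmax, VariationRatio

# Dummy classifier mirroring the test fixture below; a real use case
# would train the model first.
model = StochasticSequential()
model.add(tf.keras.layers.Input(shape=1000))
model.add(tf.keras.layers.Dropout(rate=0.5))
model.add(tf.keras.layers.Dense(10, activation="softmax"))
model.compile(loss=tf.keras.losses.CategoricalCrossentropy())

res = model.predict_quantified(
    x=np.ones((10, 1000)),
    quantifier=["MaxSoftmax", VariationRatio()],  # strings and instances can be mixed
    return_alias_dict=True,
)

# Every alias of every requested quantifier appears as a key, so the result
# can be looked up under whichever alias the caller prefers.
for alias in MaxSoftmax().aliases():
    predictions, confidences = res[alias]
```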
MiWeiss authored Feb 16, 2023
1 parent 738ed10 commit b48c4bd
Showing 4 changed files with 116 additions and 13 deletions.
20 changes: 20 additions & 0 deletions tests_unit/models_tests/test_lazy_ensemble.py
@@ -4,6 +4,7 @@
import tensorflow as tf

import uncertainty_wizard as uwiz
from uncertainty_wizard.quantifiers import StandardDeviation

DUMMY_MODEL_PATH = "tmp/dummy_lazy_ensemble"

@@ -56,6 +57,25 @@ def test_dummy_in_main_process(self):
    self.assertEqual(pred.shape, (10, 1))
    self.assertEqual(std.shape, (10, 1))

def test_result_as_dict(self):
    ensemble = uwiz.models.LazyEnsemble(
        num_models=2, model_save_path=DUMMY_MODEL_PATH, default_num_processes=0
    )
    ensemble.create(create_function=create_dummy_atomic_model)
    res = ensemble.predict_quantified(
        x=np.ones((10, 1000)),
        quantifier="std",
        num_processes=0,
        return_alias_dict=True,
    )
    assert isinstance(res, dict)
    for alias in StandardDeviation().aliases():
        assert alias in res
        assert isinstance(res[alias], tuple)
        assert len(res[alias]) == 2
        assert res[alias][0].shape == (10, 1)
        assert res[alias][1].shape == (10, 1)

def test_dummy_main_and_one_distinct_process_are_equivalent(self):
    ensemble = uwiz.models.LazyEnsemble(
        num_models=2, model_save_path=DUMMY_MODEL_PATH
54 changes: 53 additions & 1 deletion tests_unit/models_tests/test_sequential_stochastic.py
@@ -6,7 +6,7 @@
import uncertainty_wizard as uwiz
from uncertainty_wizard.internal_utils import UncertaintyWizardWarning
from uncertainty_wizard.models import StochasticSequential
-from uncertainty_wizard.quantifiers import StandardDeviation
+from uncertainty_wizard.quantifiers import MaxSoftmax, StandardDeviation, VariationRatio


class SequentialStochasticTest(TestCase):
@@ -17,6 +17,58 @@ def _dummy_model():
    model.add(tf.keras.layers.Dropout(rate=0.5))
    return model

@staticmethod
def _dummy_classifier():
    model = StochasticSequential()
    model.add(tf.keras.layers.Input(shape=1000))
    model.add(tf.keras.layers.Dropout(rate=0.5))
    model.add(tf.keras.layers.Dense(10, activation="softmax"))
    # compile the model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()],
    )
    return model

def test_result_as_dict(self):
    model = self._dummy_classifier()
    x = np.ones((10, 1000))
    res = model.predict_quantified(
        x=x,
        quantifier=[
            "MaxSoftmax",
            VariationRatio(),
        ],
        return_alias_dict=True,
    )

    self.assertTrue(isinstance(res, dict))
    for key, values in res.items():
        self.assertTrue(isinstance(key, str))
        self.assertTrue(isinstance(values, tuple))
        self.assertEqual(len(values), 2)
        self.assertEqual(values[0].shape, (10,))
        self.assertEqual(values[1].shape, (10,))

    for q in [MaxSoftmax(), VariationRatio()]:
        for a in q.aliases():
            self.assertTrue(a in res.keys())

def test_return_type_default_multi_quant(self):
    model = self._dummy_classifier()
    x = np.ones((10, 1000))
    res = model.predict_quantified(x=x, quantifier=["MaxSoftmax", VariationRatio()])
    self.assertTrue(isinstance(res, list))
    self.assertEqual(len(res), 2)

def test_return_type_default_single_quant(self):
    model = self._dummy_classifier()
    x = np.ones((10, 1000))
    res = model.predict_quantified(x=x, quantifier="MaxSoftmax")
    self.assertTrue(isinstance(res, tuple))
    self.assertEqual(len(res), 2)

def test_predict_is_deterministic(self):
    model = self._dummy_model()
    y = model.predict(x=np.ones((10, 1000)))
28 changes: 22 additions & 6 deletions uncertainty_wizard/models/_stochastic/_abstract_stochastic.py
@@ -262,6 +262,7 @@ def predict_quantified(
steps=None,
as_confidence: Union[None, bool] = None,
broadcaster: Broadcaster = None,
return_alias_dict: bool = False,
):
"""
Calculates predictions and uncertainties (or confidences) according to the passed quantifier(s).
@@ -277,7 +278,10 @@
:param as_confidence: If true, uncertainties are multiplied by (-1),
if false, confidences are multiplied by (-1). Default: No transformations.
:param broadcaster: Sampling Related Dependencies. If None, the DefaultBroadcaster will be used.
-:return: A tuple (predictions, uncertainties_or_confidences) if a single quantifier was
+:param return_alias_dict: If true, the result is returned as a dictionary with the quantifier aliases as keys.
+:return: If `return_alias_dict=True`, a dict with all quantifier aliases as keys
+    and (predictions, uncertainties_or_confidences) as values.
+    Otherwise (default), a tuple (predictions, uncertainties_or_confidences) if a single quantifier was
passed as string or instance, or a collection of such tuples if the quantifiers were passed as an iterable.
"""
all_q, pp_q, sample_q, return_single_tuple = self._quantifiers_as_list(
@@ -313,17 +317,25 @@
)

results = self._run_quantifiers(
-    as_confidence, point_prediction_scores, all_q, stochastic_scores
+    as_confidence,
+    point_prediction_scores,
+    all_q,
+    stochastic_scores,
+    as_dict=return_alias_dict,
)
-if return_single_tuple:
+if return_single_tuple and not return_alias_dict:
    return results[0]
return results

@staticmethod
def _run_quantifiers(
-    as_confidence, point_prediction_scores, quantifiers, stochastic_scores
+    as_confidence,
+    point_prediction_scores,
+    quantifiers,
+    stochastic_scores,
+    as_dict=False,
):
-results = []
+results = dict() if as_dict else []
for q in quantifiers:
    if q.takes_samples():
        assert stochastic_scores is not None, (
@@ -341,7 +353,11 @@ def _run_quantifiers(
    superv_scores = q.cast_conf_or_unc(
        as_confidence=as_confidence, superv_scores=superv_scores
    )
-   results.append((predictions, superv_scores))
+   if as_dict:
+       for alias in q.aliases():
+           results[alias] = (predictions, superv_scores)
+   else:
+       results.append((predictions, superv_scores))
return results

@staticmethod
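In dict mode, `_run_quantifiers` registers each result once per alias, so the same `(predictions, superv_scores)` tuple is reachable under every name a quantifier answers to. A self-contained sketch of that fan-out — `FakeQuantifier` and its alias strings are illustrative stand-ins, not part of uncertainty-wizard:

```python
class FakeQuantifier:
    # Illustrative stand-in: only the aliases() method matters here.
    def aliases(self):
        return ["variation_ratio", "vr"]  # alias strings are made up

results = dict()
q = FakeQuantifier()
predictions, superv_scores = [0, 1, 1], [0.2, 0.4, 0.1]  # dummy values
for alias in q.aliases():
    results[alias] = (predictions, superv_scores)

# All aliases point to the very same tuple object; no copies are made.
assert results["vr"] is results["variation_ratio"]
```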
27 changes: 21 additions & 6 deletions uncertainty_wizard/models/ensemble_utils/_lazy_ensemble.py
@@ -469,6 +469,7 @@ def predict_quantified(
num_processes=None,
context=None,
models: Optional[Iterable[int]] = None,
return_alias_dict: bool = False,
):
"""
Utility function to make quantified predictions on numpy arrays.
@@ -486,8 +487,11 @@
:param context: A contextmanager which prepares a newly created process for execution
    (e.g. by configuring the gpus). See class docstring for explanation of default values.
:param models: A list of model indices to use for prediction. Default: `None` (all models).
-:return: A tuple (predictions, uncertainties_or_confidences) if a single quantifier was passed as string
-    or instance, or a collection of such tuples if the passed quantifiers was an iterable.
+:param return_alias_dict: If true, the result is returned as a dictionary with the quantifier aliases as keys.
+:return: If `return_alias_dict=True`, a dict with all quantifier aliases as keys
+    and (predictions, uncertainties_or_confidences) as values.
+    Otherwise (default), a tuple (predictions, uncertainties_or_confidences) if a single quantifier was
+    passed as string or instance, or a collection of such tuples if the quantifiers were passed as an iterable.
"""
if verbose > 0:
    warnings.warn("Verbosity not yet supported in lazy ensemble models.")
@@ -501,6 +505,7 @@
num_processes=num_processes,
context=context,
models=models,
return_alias_dict=return_alias_dict,
)

def quantify_predictions(
@@ -511,6 +516,7 @@
num_processes: int = None,
context: Callable[[int], EnsembleContextManager] = None,
models: Optional[Iterable[int]] = None,
return_alias_dict: bool = False,
):
"""
A utility function to make predictions on all atomic models and then infer overall predictions and uncertainty
@@ -525,7 +531,11 @@
:param num_processes: The number of processes to use. Default: The default or value specified when creating the lazy ensemble.
:param context: A contextmanager which prepares a newly created process for execution (e.g. by configuring the gpus). See class docstring for explanation of default values.
:param models: A list of model indices to use for prediction. Default: `None` (all models).
-:return: A tuple (predictions, uncertainties_or_confidences) if a single quantifier was passed as string or instance, or a collection of such tuples if the passed quantifiers was an iterable.
+:param return_alias_dict: If true, the result is returned as a dictionary with the quantifier aliases as keys.
+:return: If `return_alias_dict=True`, a dict with all quantifier aliases as keys
+    and (predictions, uncertainties_or_confidences) as values.
+    Otherwise (default), a tuple (predictions, uncertainties_or_confidences) if a single quantifier was
+    passed as string or instance, or a collection of such tuples if the quantifiers were passed as an iterable.
"""
all_q, pp_q, sample_q, return_single_tuple = self._quantifiers_as_list(
quantifier
@@ -549,14 +559,19 @@
scores = np.empty(scores_shape)
scores[:, i] = predictions

-results = []
+results = dict() if return_alias_dict else list()
for q in all_q:
    predictions, superv_scores = q.calculate(scores)
    superv_scores = q.cast_conf_or_unc(
        as_confidence=as_confidence, superv_scores=superv_scores
    )
-   results.append((predictions, superv_scores))
+   # Add the predictions and superv_scores to the results
+   if return_alias_dict:
+       for alias in q.aliases():
+           results[alias] = (predictions, superv_scores)
+   else:
+       results.append((predictions, superv_scores))
-if return_single_tuple:
+if return_single_tuple and not return_alias_dict:
    return results[0]
return results
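Together with the `return_single_tuple and not return_alias_dict` guard, the return-type contract is: a tuple for a single quantifier, a list for an iterable of quantifiers, and always a dict once `return_alias_dict=True`. A sketch assuming a trained `ensemble` and inputs `x` as in the tests above:

```python
# Single quantifier, default mode: a plain (predictions, scores) tuple.
pred, std = ensemble.predict_quantified(x=x, quantifier="std")

# Iterable of quantifiers, default mode: a list with one tuple per entry,
# in the order the quantifiers were passed (here a one-element list).
[(pred2, std2)] = ensemble.predict_quantified(x=x, quantifier=["std"])

# With return_alias_dict=True the result is a dict even for one quantifier;
# the guard above skips the results[0] shortcut in that case.
res = ensemble.predict_quantified(x=x, quantifier="std", return_alias_dict=True)
```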
