Clean up tf.debugging.asserts
Uri Granta committed Oct 11, 2024
1 parent dbb7d46 commit 8f2edc8
Showing 12 changed files with 64 additions and 80 deletions.
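
The changes below all follow one pattern: tf.debugging.Assert(condition, data) takes a list of tensors to print when the condition is false, and the old code passed a throwaway [tf.constant([])] purely to fill that argument. The cleanup passes a plain empty list instead, which also lets several multi-line calls collapse onto a single line. A minimal sketch of the before/after, assuming eager execution; the dataset variable here is illustrative rather than taken from the diff:

import tensorflow as tf

dataset = object()  # stand-in for whatever object is being checked

# Before: a dummy empty tensor passed only to satisfy the data argument.
tf.debugging.Assert(dataset is not None, [tf.constant([])])

# After: an empty list, since there is nothing useful to print on failure.
tf.debugging.Assert(dataset is not None, [])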
@@ -42,9 +42,7 @@ class DumbTrajectorySampler(RandomFourierFeatureTrajectorySampler):
"""A RandomFourierFeatureTrajectorySampler that doesn't update trajectories in place."""

def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunction:
-tf.debugging.Assert(
-    isinstance(trajectory, feature_decomposition_trajectory), [tf.constant([])]
-)
+tf.debugging.Assert(isinstance(trajectory, feature_decomposition_trajectory), [])
return self.get_trajectory()


12 changes: 6 additions & 6 deletions trieste/acquisition/function/entropy.py
@@ -128,7 +128,7 @@ def prepare_acquisition_function(
:exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater than one.
:raise tf.errors.InvalidArgumentError: If ``dataset`` is empty.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -150,10 +150,10 @@ def update_acquisition_function(
:param model: The model.
:param dataset: The data from the observer.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(isinstance(function, min_value_entropy_search), [tf.constant([])])
+tf.debugging.Assert(isinstance(function, min_value_entropy_search), [])

query_points = self._search_space.sample(num_samples=self._grid_size)
tf.debugging.assert_same_float_dtype([dataset.query_points, query_points])
@@ -334,7 +334,7 @@ def prepare_acquisition_function(
f"covariance_between_points and get_observation_noise; received {model!r}"
)

-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -363,10 +363,10 @@ def update_acquisition_function(
for the current step. Defaults to ``True``.
:return: The updated acquisition function.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(self._quality_term is not None, [tf.constant([])])
+tf.debugging.Assert(self._quality_term is not None, [])

if new_optimization_step:
self._update_quality_term(dataset, model)
52 changes: 21 additions & 31 deletions trieste/acquisition/function/function.py
@@ -86,7 +86,7 @@ def update_acquisition_function(
tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(isinstance(function, probability_below_threshold), [tf.constant([])])
+tf.debugging.Assert(isinstance(function, probability_below_threshold), [])
mean, _ = model.predict(dataset.query_points)
eta = tf.reduce_min(mean, axis=0)[0]
function.update(eta) # type: ignore
@@ -127,7 +127,7 @@ def prepare_acquisition_function(
greater than one.
:raise tf.errors.InvalidArgumentError: If ``dataset`` is empty.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -161,10 +161,10 @@ def update_acquisition_function(
:param model: The model.
:param dataset: The data from the observer. Must be populated.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(isinstance(function, expected_improvement), [tf.constant([])])
+tf.debugging.Assert(isinstance(function, expected_improvement), [])

# Check feasibility against any explicit constraints in the search space.
if self._search_space is not None and self._search_space.has_constraints:
@@ -251,7 +251,7 @@ def prepare_acquisition_function(
f"AugmentedExpectedImprovement only works with models that support "
f"get_observation_noise; received {model!r}"
)
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
mean, _ = model.predict(dataset.query_points)
@@ -269,10 +269,10 @@ def update_acquisition_function(
:param model: The model.
:param dataset: The data from the observer. Must be populated.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(isinstance(function, augmented_expected_improvement), [tf.constant([])])
+tf.debugging.Assert(isinstance(function, augmented_expected_improvement), [])
mean, _ = model.predict(dataset.query_points)
eta = tf.reduce_min(mean, axis=0)
function.update(eta) # type: ignore
@@ -669,7 +669,7 @@ def prepare_acquisition_function(
:raise KeyError: If `objective_tag` is not found in ``datasets`` and ``models``.
:raise tf.errors.InvalidArgumentError: If the objective data is empty.
"""
-tf.debugging.Assert(datasets is not None, [tf.constant([])])
+tf.debugging.Assert(datasets is not None, [])
datasets = cast(Mapping[Tag, Dataset], datasets)

objective_model = models[self._objective_tag]
@@ -719,7 +719,7 @@ def update_acquisition_function(
:param models: The models for each tag.
:param datasets: The data from the observer.
"""
-tf.debugging.Assert(datasets is not None, [tf.constant([])])
+tf.debugging.Assert(datasets is not None, [])
datasets = cast(Mapping[Tag, Dataset], datasets)

objective_model = models[self._objective_tag]
@@ -730,7 +730,7 @@ def update_acquisition_function(
message="Expected improvement is defined with respect to existing points in the"
" objective data, but the objective data is empty.",
)
-tf.debugging.Assert(self._constraint_fn is not None, [tf.constant([])])
+tf.debugging.Assert(self._constraint_fn is not None, [])

constraint_fn = cast(AcquisitionFunction, self._constraint_fn)
self._constraint_builder.update_acquisition_function(
@@ -777,9 +777,7 @@ def _update_expected_improvement_fn(
if self._expected_improvement_fn is None:
self._expected_improvement_fn = expected_improvement(objective_model, eta)
else:
-tf.debugging.Assert(
-    isinstance(self._expected_improvement_fn, expected_improvement), [tf.constant([])]
-)
+tf.debugging.Assert(isinstance(self._expected_improvement_fn, expected_improvement), [])
self._expected_improvement_fn.update(eta) # type: ignore


@@ -830,7 +828,7 @@ def prepare_acquisition_function(

sampler = model.reparam_sampler(self._sample_size)

-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -858,12 +856,10 @@ def update_acquisition_function(
:param model: The model. Must have output dimension [1]. Unused here.
:param dataset: The data from the observer. Cannot be empty
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(
-    isinstance(function, monte_carlo_expected_improvement), [tf.constant([])]
-)
+tf.debugging.Assert(isinstance(function, monte_carlo_expected_improvement), [])
sampler = function._sampler # type: ignore
sampler.reset_sampler()
samples_at_query_points = sampler.sample(
@@ -974,7 +970,7 @@ def prepare_acquisition_function(

sampler = model.reparam_sampler(self._sample_size)

-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -1002,12 +998,10 @@ def update_acquisition_function(
:param model: The model. Must have output dimension [1]. Unused here
:param dataset: The data from the observer. Cannot be empty.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(
-    isinstance(function, monte_carlo_augmented_expected_improvement), [tf.constant([])]
-)
+tf.debugging.Assert(isinstance(function, monte_carlo_augmented_expected_improvement), [])
sampler = function._sampler # type: ignore
sampler.reset_sampler()
samples_at_query_points = sampler.sample(
@@ -1111,7 +1105,7 @@ def prepare_acquisition_function(
:raise ValueError (or InvalidArgumentError): If ``dataset`` is not populated, or ``model``
does not have an event shape of [1].
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -1135,12 +1129,10 @@ def update_acquisition_function(
:param model: The model. Must have event shape [1].
:param dataset: The data from the observer. Must be populated.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(
-    isinstance(function, batch_monte_carlo_expected_improvement), [tf.constant([])]
-)
+tf.debugging.Assert(isinstance(function, batch_monte_carlo_expected_improvement), [])
mean, _ = model.predict(dataset.query_points)
eta = tf.reduce_min(mean, axis=0)
function.update(eta) # type: ignore
@@ -1848,9 +1840,7 @@ def update_acquisition_function(
:param model: The model.
:param dataset: Unused.
"""
-tf.debugging.Assert(
-    isinstance(function, multiple_optimism_lower_confidence_bound), [tf.constant([])]
-)
+tf.debugging.Assert(isinstance(function, multiple_optimism_lower_confidence_bound), [])
return function # nothing to update


8 changes: 4 additions & 4 deletions trieste/acquisition/function/greedy_batch.py
@@ -135,7 +135,7 @@ def prepare_acquisition_function(
:return: The (log) expected improvement penalized with respect to the pending points.
:raise tf.errors.InvalidArgumentError: If the ``dataset`` is empty.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")

@@ -164,10 +164,10 @@ def update_acquisition_function(
for the current step. Defaults to ``True``.
:return: The updated acquisition function.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(self._base_acquisition_function is not None, [tf.constant([])])
+tf.debugging.Assert(self._base_acquisition_function is not None, [])

if new_optimization_step:
self._update_base_acquisition_function(dataset, model)
@@ -447,7 +447,7 @@ def __init__(
See class docs for more details.
:raise tf.errors.InvalidArgumentError: If ``fantasize_method`` is not "KB" or "sample".
"""
-tf.debugging.Assert(fantasize_method in ["KB", "sample"], [tf.constant([])])
+tf.debugging.Assert(fantasize_method in ["KB", "sample"], [])

if base_acquisition_function_builder is None:
base_acquisition_function_builder = ExpectedImprovement()
26 changes: 11 additions & 15 deletions trieste/acquisition/function/multi_objective.py
@@ -91,7 +91,7 @@ def prepare_acquisition_function(
:param dataset: The data from the observer. Must be populated.
:return: The expected hypervolume improvement acquisition function.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
mean, _ = model.predict(dataset.query_points)
@@ -121,10 +121,10 @@ def update_acquisition_function(
:param model: The model.
:param dataset: The data from the observer. Must be populated.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
-tf.debugging.Assert(isinstance(function, expected_hv_improvement), [tf.constant([])])
+tf.debugging.Assert(isinstance(function, expected_hv_improvement), [])
mean, _ = model.predict(dataset.query_points)

if callable(self._ref_point_spec):
@@ -319,7 +319,7 @@ def prepare_acquisition_function(
:param dataset: The data from the observer. Must be populated.
:return: The batch expected hypervolume improvement acquisition function.
"""
-tf.debugging.Assert(dataset is not None, [tf.constant([])])
+tf.debugging.Assert(dataset is not None, [])
dataset = cast(Dataset, dataset)
tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.")
mean, _ = model.predict(dataset.query_points)
@@ -564,9 +564,9 @@ def prepare_acquisition_function(
:return: The HIPPO acquisition function.
:raise tf.errors.InvalidArgumentError: If the ``dataset`` is empty.
"""
-tf.debugging.Assert(datasets is not None, [tf.constant([])])
+tf.debugging.Assert(datasets is not None, [])
datasets = cast(Mapping[Tag, Dataset], datasets)
-tf.debugging.Assert(datasets[self._objective_tag] is not None, [tf.constant([])])
+tf.debugging.Assert(datasets[self._objective_tag] is not None, [])
tf.debugging.assert_positive(
len(datasets[self._objective_tag]),
message=f"{self._objective_tag} dataset must be populated.",
@@ -599,14 +599,14 @@ def update_acquisition_function(
for the current step. Defaults to ``True``.
:return: The updated acquisition function.
"""
-tf.debugging.Assert(datasets is not None, [tf.constant([])])
+tf.debugging.Assert(datasets is not None, [])
datasets = cast(Mapping[Tag, Dataset], datasets)
-tf.debugging.Assert(datasets[self._objective_tag] is not None, [tf.constant([])])
+tf.debugging.Assert(datasets[self._objective_tag] is not None, [])
tf.debugging.assert_positive(
len(datasets[self._objective_tag]),
message=f"{self._objective_tag} dataset must be populated.",
)
-tf.debugging.Assert(self._base_acquisition_function is not None, [tf.constant([])])
+tf.debugging.Assert(self._base_acquisition_function is not None, [])

if new_optimization_step:
self._update_base_acquisition_function(models, datasets)
@@ -689,9 +689,7 @@ def __init__(self, model: ProbabilisticModel, pending_points: TensorType):
:return: The penalization function. This function will raise
:exc:`ValueError` or :exc:`~tf.errors.InvalidArgumentError` if used with a batch size
greater than one."""
-tf.debugging.Assert(
-    pending_points is not None and len(pending_points) != 0, [tf.constant([])]
-)
+tf.debugging.Assert(pending_points is not None and len(pending_points) != 0, [])

self._model = model
self._pending_points = tf.Variable(pending_points, shape=[None, *pending_points.shape[1:]])
@@ -701,9 +699,7 @@ def __init__(self, model: ProbabilisticModel, pending_points: TensorType):

def update(self, pending_points: TensorType) -> None:
"""Update the penalizer with new pending points."""
-tf.debugging.Assert(
-    pending_points is not None and len(pending_points) != 0, [tf.constant([])]
-)
+tf.debugging.Assert(pending_points is not None and len(pending_points) != 0, [])

self._pending_points.assign(pending_points)
pending_means, pending_vars = self._model.predict(self._pending_points)
2 changes: 1 addition & 1 deletion trieste/acquisition/optimizer.py
@@ -760,7 +760,7 @@ def get_bounds_of_box_relaxation_around_point(
:param current_point: The point at which to make the continuous relaxation.
:return: Bounds for the Scipy optimizer.
"""
-tf.debugging.Assert(isinstance(space, TaggedProductSearchSpace), [tf.constant([])])
+tf.debugging.Assert(isinstance(space, TaggedProductSearchSpace), [])

space_with_fixed_discrete = space
for tag in space.subspace_tags:
2 changes: 1 addition & 1 deletion trieste/acquisition/rule.py
@@ -636,7 +636,7 @@ def acquire(
def state_func(
state: AsynchronousRuleState | None,
) -> tuple[AsynchronousRuleState | None, TensorType]:
-tf.debugging.Assert(self._acquisition_function is not None, [tf.constant([])])
+tf.debugging.Assert(self._acquisition_function is not None, [])

if state is None:
state = AsynchronousRuleState(None)
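
For context, here is a hedged sketch, not part of this commit, of how these guards behave when they fail under eager execution: tf.debugging.Assert raises tf.errors.InvalidArgumentError when its condition is false, and tf.debugging.assert_positive does the same using the supplied message, which is what the "Dataset must be populated." checks above rely on. The records tensor below is illustrative only.

import tensorflow as tf

records = tf.zeros([0, 2])  # stand-in for an empty dataset (zero rows)

try:
    # Mirrors the assert_positive(len(dataset), ...) checks in the diff.
    tf.debugging.assert_positive(len(records), message="Dataset must be populated.")
except tf.errors.InvalidArgumentError as err:
    print(err.message)  # includes "Dataset must be populated."

try:
    # Mirrors the cleaned-up Assert calls: an empty data list, nothing extra to print.
    tf.debugging.Assert(len(records) != 0, [])
except tf.errors.InvalidArgumentError:
    print("Assert raised InvalidArgumentError because the condition was false")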