Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Categorical search space tweaks #869

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions tests/unit/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SearchSpace,
TaggedMultiSearchSpace,
TaggedProductSearchSpace,
cast_encoder,
one_hot_encoder,
)
from trieste.types import TensorType
Expand Down Expand Up @@ -1759,6 +1760,11 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None:
tf.constant([[0], [0]], dtype=tf.float64),
tf.constant([[1], [1]], dtype=tf.float64),
),
(
CategoricalSearchSpace(["Y", "N"]),
tf.constant([[0], [1], [0]], dtype=tf.float64),
tf.constant([[0], [1], [0]], dtype=tf.float64),
),
(
CategoricalSearchSpace(["R", "G", "B"], dtype=tf.float32),
tf.constant([[0], [2], [1]], dtype=tf.float32),
Expand All @@ -1777,13 +1783,13 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None:
(
CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
tf.constant([[0, 0], [2, 0], [1, 1]], dtype=tf.float64),
tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float64),
tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]], dtype=tf.float64),
),
(
CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]], dtype=tf.float64),
tf.constant(
[[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]],
[[[1, 0, 0, 0], [1, 0, 0, 0]], [[0, 0, 1, 0], [0, 1, 0, 1]]],
dtype=tf.float64,
),
),
Expand Down Expand Up @@ -1824,6 +1830,12 @@ def test_categorical_search_space_one_hot_encoding(
pytest.param(
CategoricalSearchSpace(["Y", "N"]),
tf.constant([[0], [2], [1]]),
ValueError,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps check error message at line 660?

id="Out of range binary input value",
),
pytest.param(
CategoricalSearchSpace(["Y", "N", "maybe"]),
tf.constant([[0], [3], [1]]),
InvalidArgumentError,
id="Out of range input value",
),
Expand Down Expand Up @@ -1859,3 +1871,19 @@ def test_unbound_search_spaces(
space.lower
with pytest.raises(AttributeError):
space.upper


@pytest.mark.parametrize("input_dtype", [None, tf.float64, tf.float32])
@pytest.mark.parametrize("output_dtype", [None, tf.float64, tf.float32])
def test_cast_encoder(input_dtype: Optional[tf.DType], output_dtype: Optional[tf.DType]) -> None:

query_points = tf.constant([1, 2, 3], dtype=tf.int32)

def add_encoder(x: TensorType) -> TensorType:
assert x.dtype is (input_dtype or tf.int32)
return x + 1

encoder = cast_encoder(add_encoder, input_dtype=input_dtype, output_dtype=output_dtype)
points = encoder(query_points)
assert points.dtype is (output_dtype or input_dtype or tf.int32)
npt.assert_array_equal(tf.cast(query_points + 1, points.dtype), points)
33 changes: 31 additions & 2 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,24 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction:
return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x


def cast_encoder(
encoder: EncoderFunction,
input_dtype: Optional[tf.DType] = None,
output_dtype: Optional[tf.DType] = None,
) -> EncoderFunction:
"A utility function for casting the input and/or output of an encoder."

def cast_and_encode(x: TensorType) -> TensorType:
if input_dtype is not None:
x = tf.cast(x, input_dtype)
y = encoder(x)
if output_dtype is not None:
y = tf.cast(y, output_dtype)
return y

return cast_and_encode


def one_hot_encoded_space(space: SearchSpace) -> SearchSpace:
"A bounded search space corresponding to the one-hot encoding of the given space."

Expand Down Expand Up @@ -633,7 +651,14 @@ def tags(self) -> Sequence[Sequence[str]]:

@property
def one_hot_encoder(self) -> EncoderFunction:
"""A one-hot encoder for the numerical indices."""
"""A one-hot encoder for the numerical indices. Note that binary categories
are left unchanged instead of adding an unnecessary second feature."""

def binary_encoder(x: TensorType) -> TensorType:
# no need to one-hot encode binary categories (but we should still validate)
if tf.reduce_any((x != 0) & (x != 1)):
raise ValueError(f"Invalid values {tf.boolean_mask(x, ((x != 0) & (x != 1)))}")
return x

def encoder(x: TensorType) -> TensorType:
flat_x, unflatten = flatten_leading_dims(x)
Expand All @@ -644,7 +669,11 @@ def encoder(x: TensorType) -> TensorType:
)
columns = tf.split(flat_x, flat_x.shape[-1], axis=1)
encoders = [
tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot")
(
binary_encoder
if len(ts) == 2
else tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot")
)
for ts in self.tags
]
encoded = tf.concat(
Expand Down
Loading