From 07ec2519d6bb8e4618213a0ebb969d35edb58a0b Mon Sep 17 00:00:00 2001
From: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Date: Fri, 21 Feb 2025 11:22:30 +0100
Subject: [PATCH] feat: allow use of functional API by sanitizing build input
 shapes (#332)

---
 bayesflow/networks/deep_set/deep_set.py       |  2 ++
 .../networks/deep_set/equivariant_module.py   |  2 ++
 .../networks/deep_set/invariant_module.py     |  2 ++
 bayesflow/networks/lstnet/lstnet.py           |  2 ++
 bayesflow/networks/lstnet/skip_recurrent.py   |  2 ++
 bayesflow/networks/summary_network.py         |  2 ++
 bayesflow/utils/decorators.py                 | 18 ++++++++++++++++++
 7 files changed, 30 insertions(+)

diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py
index bd7fde770..953f90aa7 100644
--- a/bayesflow/networks/deep_set/deep_set.py
+++ b/bayesflow/networks/deep_set/deep_set.py
@@ -5,6 +5,7 @@
 
 from bayesflow.types import Tensor
 from bayesflow.utils import filter_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 
 from .equivariant_module import EquivariantModule
 from .invariant_module import InvariantModule
@@ -78,6 +79,7 @@ def __init__(
         self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
         self.summary_dim = summary_dim
 
+    @sanitize_input_shape
     def build(self, input_shape):
         super().build(input_shape)
         self.call(keras.ops.zeros(input_shape))
diff --git a/bayesflow/networks/deep_set/equivariant_module.py b/bayesflow/networks/deep_set/equivariant_module.py
index 5ecd8ba6e..ab933540c 100644
--- a/bayesflow/networks/deep_set/equivariant_module.py
+++ b/bayesflow/networks/deep_set/equivariant_module.py
@@ -5,6 +5,7 @@
 from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Tensor
+from bayesflow.utils.decorators import sanitize_input_shape
 
 from .invariant_module import InvariantModule
 
@@ -66,6 +67,7 @@ def __init__(
 
         self.layer_norm = layers.LayerNormalization() if layer_norm else None
 
+    @sanitize_input_shape
     def build(self, input_shape):
         self.call(keras.ops.zeros(input_shape))
 
diff --git a/bayesflow/networks/deep_set/invariant_module.py b/bayesflow/networks/deep_set/invariant_module.py
index 02f4052aa..b29a7cb5c 100644
--- a/bayesflow/networks/deep_set/invariant_module.py
+++ b/bayesflow/networks/deep_set/invariant_module.py
@@ -6,6 +6,7 @@
 
 from bayesflow.types import Tensor
 from bayesflow.utils import find_pooling
+from bayesflow.utils.decorators import sanitize_input_shape
 
 
 @serializable(package="bayesflow.networks")
@@ -76,6 +77,7 @@ def __init__(
 
         self.pooling_layer = find_pooling(pooling, **pooling_kwargs)
 
+    @sanitize_input_shape
     def build(self, input_shape):
         self.call(keras.ops.zeros(input_shape))
 
diff --git a/bayesflow/networks/lstnet/lstnet.py b/bayesflow/networks/lstnet/lstnet.py
index bb2acb9d8..695907ef8 100644
--- a/bayesflow/networks/lstnet/lstnet.py
+++ b/bayesflow/networks/lstnet/lstnet.py
@@ -2,6 +2,7 @@
 from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Tensor
+from bayesflow.utils.decorators import sanitize_input_shape
 
 from .skip_recurrent import SkipRecurrentNet
 from ..summary_network import SummaryNetwork
@@ -78,6 +79,7 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
         x = self.output_projector(x)
         return x
 
+    @sanitize_input_shape
     def build(self, input_shape):
         super().build(input_shape)
         self.call(keras.ops.zeros(input_shape))
diff --git a/bayesflow/networks/lstnet/skip_recurrent.py b/bayesflow/networks/lstnet/skip_recurrent.py
index f25c8ba78..f754de42e 100644
--- a/bayesflow/networks/lstnet/skip_recurrent.py
+++ b/bayesflow/networks/lstnet/skip_recurrent.py
@@ -3,6 +3,7 @@
 
 from bayesflow.types import Tensor
 from bayesflow.utils import keras_kwargs, find_recurrent_net
+from bayesflow.utils.decorators import sanitize_input_shape
 
 
 @serializable(package="bayesflow.networks")
@@ -58,5 +59,6 @@ def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
         skip_summary = self.skip_recurrent(self.skip_conv(time_series), training=training)
         return keras.ops.concatenate((direct_summary, skip_summary), axis=-1)
 
+    @sanitize_input_shape
     def build(self, input_shape):
         self.call(keras.ops.zeros(input_shape))
diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py
index dce40fdfd..872248332 100644
--- a/bayesflow/networks/summary_network.py
+++ b/bayesflow/networks/summary_network.py
@@ -3,6 +3,7 @@
 from bayesflow.metrics.functional import maximum_mean_discrepancy
 from bayesflow.types import Tensor
 from bayesflow.utils import find_distribution, keras_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 
 
 class SummaryNetwork(keras.Layer):
@@ -10,6 +11,7 @@ def __init__(self, base_distribution: str = None, **kwargs):
         super().__init__(**keras_kwargs(kwargs))
         self.base_distribution = find_distribution(base_distribution)
 
+    @sanitize_input_shape
     def build(self, input_shape):
         if self.base_distribution is not None:
             output_shape = keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
diff --git a/bayesflow/utils/decorators.py b/bayesflow/utils/decorators.py
index 000bc6837..e8c89aa7b 100644
--- a/bayesflow/utils/decorators.py
+++ b/bayesflow/utils/decorators.py
@@ -2,6 +2,7 @@
 from functools import wraps
 import inspect
 from typing import overload, TypeVar
+from bayesflow.types import Shape
 
 
 Fn = TypeVar("Fn", bound=Callable[..., any])
@@ -110,3 +111,20 @@ def callback(x):
 
     fn = alias("batch_shape", "batch_size")(fn)
     return fn
+
+
+def sanitize_input_shape(fn: Callable):
+    """Decorator to replace the first dimension in input_shape with a dummy batch size if it is None"""
+
+    # The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
+    # causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
+    # in build. To alleviate those problems, this decorator replaces None with an arbitrary batch size.
+    def callback(input_shape: Shape) -> Shape:
+        if input_shape[0] is None:
+            input_shape = list(input_shape)
+            input_shape[0] = 32
+            return tuple(input_shape)
+        return input_shape
+
+    fn = argument_callback("input_shape", callback)(fn)
+    return fn
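For context, a minimal standalone sketch of the shape handling the new decorator relies on. The `sanitize` helper below is illustrative only (it mirrors the decorator's internal callback and its fallback batch size of 32) and is not part of the patch:

# Illustrative only: reproduces the shape-sanitizing logic introduced above.
def sanitize(input_shape):
    # Replace a leading None (as passed by the Keras functional API) with a dummy batch size.
    if input_shape[0] is None:
        input_shape = list(input_shape)
        input_shape[0] = 32  # arbitrary placeholder batch size, matching the patch
        return tuple(input_shape)
    return input_shape

print(sanitize((None, 128, 3)))  # (32, 128, 3): keras.ops.zeros(...) in build() now gets concrete dims
print(sanitize((16, 128, 3)))    # (16, 128, 3): already concrete, left unchanged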