Skip to content

Commit

Permalink
feat: allow use of functional API by sanitizing build input shapes (#332
Browse files Browse the repository at this point in the history
)
  • Loading branch information
vpratz authored Feb 21, 2025
1 parent ec2a0e2 commit 07ec251
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bayesflow/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

from .equivariant_module import EquivariantModule
from .invariant_module import InvariantModule
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim

@sanitize_input_shape
def build(self, input_shape):
super().build(input_shape)
self.call(keras.ops.zeros(input_shape))
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/deep_set/equivariant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils.decorators import sanitize_input_shape
from .invariant_module import InvariantModule


Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(

self.layer_norm = layers.LayerNormalization() if layer_norm else None

@sanitize_input_shape
def build(self, input_shape):
self.call(keras.ops.zeros(input_shape))

Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/deep_set/invariant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bayesflow.types import Tensor
from bayesflow.utils import find_pooling
from bayesflow.utils.decorators import sanitize_input_shape


@serializable(package="bayesflow.networks")
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(

self.pooling_layer = find_pooling(pooling, **pooling_kwargs)

@sanitize_input_shape
def build(self, input_shape):
self.call(keras.ops.zeros(input_shape))

Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/lstnet/lstnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils.decorators import sanitize_input_shape
from .skip_recurrent import SkipRecurrentNet
from ..summary_network import SummaryNetwork

Expand Down Expand Up @@ -78,6 +79,7 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
x = self.output_projector(x)
return x

@sanitize_input_shape
def build(self, input_shape):
super().build(input_shape)
self.call(keras.ops.zeros(input_shape))
2 changes: 2 additions & 0 deletions bayesflow/networks/lstnet/skip_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs, find_recurrent_net
from bayesflow.utils.decorators import sanitize_input_shape


@serializable(package="bayesflow.networks")
Expand Down Expand Up @@ -58,5 +59,6 @@ def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
skip_summary = self.skip_recurrent(self.skip_conv(time_series), training=training)
return keras.ops.concatenate((direct_summary, skip_summary), axis=-1)

@sanitize_input_shape
def build(self, input_shape):
self.call(keras.ops.zeros(input_shape))
2 changes: 2 additions & 0 deletions bayesflow/networks/summary_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from bayesflow.metrics.functional import maximum_mean_discrepancy
from bayesflow.types import Tensor
from bayesflow.utils import find_distribution, keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape


class SummaryNetwork(keras.Layer):
def __init__(self, base_distribution: str = None, **kwargs):
super().__init__(**keras_kwargs(kwargs))
self.base_distribution = find_distribution(base_distribution)

@sanitize_input_shape
def build(self, input_shape):
if self.base_distribution is not None:
output_shape = keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
Expand Down
18 changes: 18 additions & 0 deletions bayesflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
import inspect
from typing import overload, TypeVar
from bayesflow.types import Shape

Fn = TypeVar("Fn", bound=Callable[..., any])

Expand Down Expand Up @@ -110,3 +111,20 @@ def callback(x):
fn = alias("batch_shape", "batch_size")(fn)

return fn


def sanitize_input_shape(fn: Callable):
    """Decorator to replace the first dimension in input_shape with a dummy batch size if it is None.

    The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
    causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
    in build. To alleviate those problems, this decorator replaces None with an arbitrary batch size.

    Parameters
    ----------
    fn : Callable
        A ``build(self, input_shape)``-style callable whose ``input_shape`` argument
        should be sanitized before ``fn`` runs.

    Returns
    -------
    Callable
        The wrapped callable; its ``input_shape`` argument has a leading ``None``
        batch dimension replaced with a concrete dummy value.
    """

    def callback(input_shape: Shape) -> Shape:
        # Guard against an empty shape to avoid an IndexError; only the leading
        # (batch) dimension is sanitized — other dims are passed through as-is.
        if input_shape and input_shape[0] is None:
            # Any positive value works: the zeros tensor built from this shape is
            # only used to trace build(); 32 is an arbitrary conventional batch size.
            sanitized = list(input_shape)
            sanitized[0] = 32
            return tuple(sanitized)
        return input_shape

    fn = argument_callback("input_shape", callback)(fn)
    return fn

0 comments on commit 07ec251

Please sign in to comment.