Commit

Merge branch 'dev' into defaults

stefanradev93 committed Feb 11, 2025
2 parents 317f444 + 8210610 commit 24ea269
Showing 24 changed files with 1,098 additions and 1,003 deletions.
7 changes: 3 additions & 4 deletions bayesflow/__init__.py
@@ -10,12 +10,11 @@
workflows,
utils,
)

from .workflows import BasicWorkflow
from .approximators import ContinuousApproximator
from .adapters import Adapter
from .approximators import ContinuousApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow


def setup():
@@ -38,7 +37,7 @@ def setup():

from bayesflow.utils import logging

logging.info(f"Using backend {keras.backend.backend()!r}")
logging.debug(f"Using backend {keras.backend.backend()!r}")


# call and clean up namespace
7 changes: 7 additions & 0 deletions bayesflow/adapters/transforms/as_set.py
@@ -33,3 +33,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=2)

return data

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "AsSet":
return cls()

def get_config(self) -> dict:
return {}
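
For context, get_config and from_config are the standard Keras serialization pair: get_config captures the constructor arguments, and from_config rebuilds the object from them. Because AsSet (and AsTimeSeries below) carry no constructor state, both reduce to trivial implementations. A minimal round-trip sketch, assuming AsSet is importable from the module path shown above:

from bayesflow.adapters.transforms.as_set import AsSet

transform = AsSet()
config = transform.get_config()        # {} : the transform is stateless
restored = AsSet.from_config(config)   # an equivalent fresh instance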
7 changes: 7 additions & 0 deletions bayesflow/adapters/transforms/as_time_series.py
@@ -30,3 +30,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=2)

return data

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries":
return cls()

def get_config(self) -> dict:
return {}
5 changes: 2 additions & 3 deletions bayesflow/diagnostics/plots/loss.py
@@ -96,8 +96,7 @@ def loss(
val_step_index = val_step_index[: val_losses.shape[0]]

# Loop through loss entries and populate plot
looper = [axes] if num_row == 1 else axes.flat
for i, ax in enumerate(looper):
for i, ax in enumerate(axes.flat):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average and train_losses.columns[i] == "Loss":
@@ -127,7 +126,7 @@

# Add labels, titles, and set font sizes
add_titles_and_labels(
axes=np.atleast_1d(axes),
axes=axes,
num_row=num_row,
num_col=1,
title=["Loss Trajectory"],
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plots/recovery.py
@@ -77,7 +77,7 @@ def recovery(
if uncertainty_agg is not None:
u = uncertainty_agg(targets, axis=1)

for i, ax in enumerate(np.atleast_1d(plot_data["axes"].flat)):
for i, ax in enumerate(plot_data["axes"].flat):
if i >= plot_data["num_variables"]:
break

Empty file.
@@ -19,12 +19,12 @@
)


from ..inference_network import InferenceNetwork
from ..embeddings import FourierEmbedding
from bayesflow.networks import InferenceNetwork
from bayesflow.networks.embeddings import FourierEmbedding


@register_keras_serializable(package="bayesflow.networks")
class ContinuousConsistencyModel(InferenceNetwork):
class ContinuousTimeConsistencyModel(InferenceNetwork):
"""Implements an sCM (simple, stable, and scalable Consistency Model)
with continuous-time Consistency Training (CT) as described in [1].
The sampling procedure is taken from [2].
2 changes: 1 addition & 1 deletion bayesflow/networks/__init__.py
@@ -1,5 +1,5 @@
from .cif import CIF
from .consistency_models import ConsistencyModel, ContinuousConsistencyModel
from .consistency_models import ConsistencyModel
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
Expand Down
1 change: 0 additions & 1 deletion bayesflow/networks/consistency_models/__init__.py
@@ -1,2 +1 @@
from .consistency_model import ConsistencyModel
from .continuous_consistency_model import ContinuousConsistencyModel
150 changes: 113 additions & 37 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -1,19 +1,20 @@
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import (
expand_right_as,
find_network,
integrate,
jacobian_trace,
keras_kwargs,
optimal_transport,
serialize_value_or_type,
deserialize_value_or_type,
)
from ..inference_network import InferenceNetwork
from .integrators import EulerIntegrator
from .integrators import RK2Integrator
from .integrators import RK4Integrator


@serializable(package="bayesflow.networks")
@@ -47,48 +48,71 @@ def __init__(
self,
subnet: str | type = "mlp",
base_distribution: str = "normal",
integrator: str = "euler",
use_optimal_transport: bool = False,
loss_fn: str = "mse",
integrate_kwargs: dict[str, any] = None,
optimal_transport_kwargs: dict[str, any] = None,
**kwargs,
):
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))

self.use_optimal_transport = use_optimal_transport
self.optimal_transport_kwargs = self.OPTIMAL_TRANSPORT_DEFAULT_CONFIG.copy()
self.optimal_transport_kwargs.update(optimal_transport_kwargs or {})

if subnet == "mlp":
subnet_kwargs = self.MLP_DEFAULT_CONFIG.copy()
subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
if integrate_kwargs is None:
integrate_kwargs = {
"method": "rk45",
"steps": "adaptive",
"tolerance": 1e-3,
"min_steps": 10,
"max_steps": 100,
}

self.integrate_kwargs = integrate_kwargs

if optimal_transport_kwargs is None:
optimal_transport_kwargs = {
"method": "sinkhorn",
"cost": "euclidean",
"regularization": 0.1,
"max_steps": 100,
"tolerance": 1e-4,
}

self.loss_fn = keras.losses.get(loss_fn)

# TODO - Spawn subnet here
self.optimal_transport_kwargs = optimal_transport_kwargs

self.seed_generator = keras.random.SeedGenerator()

match integrator:
case "euler":
self.integrator = EulerIntegrator(subnet, **kwargs)
case "rk2":
self.integrator = RK2Integrator(subnet, **kwargs)
case "rk4":
self.integrator = RK4Integrator(subnet, **kwargs)
case _:
raise NotImplementedError(f"No support for {integrator} integration")
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")

# serialization: store all parameters necessary to call __init__
self.config = {
"base_distribution": base_distribution,
"integrator": integrator,
"use_optimal_transport": use_optimal_transport,
"optimal_transport_kwargs": optimal_transport_kwargs,
"integrate_kwargs": integrate_kwargs,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
super().build(xz_shape)
self.integrator.build(xz_shape, conditions_shape)
super().build(xz_shape, conditions_shape=conditions_shape)

self.output_projector.units = xz_shape[-1]
input_shape = list(xz_shape)

# construct time vector
input_shape[-1] += 1
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]

input_shape = tuple(input_shape)

self.subnet.build(input_shape)
out_shape = self.subnet.compute_output_shape(input_shape)
self.output_projector.build(out_shape)

def get_config(self):
base_config = super().get_config()
@@ -99,32 +123,80 @@ def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
t = keras.ops.convert_to_tensor(t)
t = expand_right_as(t, xz)
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))

if conditions is None:
xtc = keras.ops.concatenate([xz, t], axis=-1)
else:
xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)

return self.output_projector(self.subnet(xtc, training=training), training=training)

def _velocity_trace(
self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
) -> (Tensor, Tensor):
def f(x):
return self.velocity(x, t, conditions=conditions, training=training)

v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

return v, keras.ops.expand_dims(trace, axis=-1)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
steps = kwargs.get("steps", 100)

if density:
z, trace = self.integrator(x, conditions=conditions, steps=steps, density=True)
log_prob = self.base_distribution.log_prob(z)
log_density = log_prob + trace

def deltas(t, xz):
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))

z = state["xz"]
log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)

return z, log_density

z = self.integrator(x, conditions=conditions, steps=steps, density=False)
def deltas(t, xz):
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}

state = {"xz": x}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))

z = state["xz"]

return z

def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
steps = kwargs.get("steps", 100)

if density:
x, trace = self.integrator(z, conditions=conditions, steps=steps, density=True, inverse=True)
log_prob = self.base_distribution.log_prob(z)
log_density = log_prob - trace

def deltas(t, xz):
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))

x = state["xz"]
log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)

return x, log_density

x = self.integrator(z, conditions=conditions, steps=steps, density=False, inverse=True)
def deltas(t, xz):
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}

state = {"xz": z}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))

x = state["xz"]

return x

def compute_metrics(
@@ -136,7 +208,11 @@ def compute_metrics(
else:
# not pre-configured, resample
x1 = x
x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
if not self.built:
xz_shape = keras.ops.shape(x1)
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
self.build(xz_shape, conditions_shape)
x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])

if self.use_optimal_transport:
x1, x0, conditions = optimal_transport(
@@ -151,9 +227,9 @@

base_metrics = super().compute_metrics(x1, conditions, stage)

predicted_velocity = self.integrator.velocity(x, t, conditions)
predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity)
loss = self.loss_fn(target_velocity, predicted_velocity)
loss = keras.ops.mean(loss)

return base_metrics | {"loss": loss}
3 changes: 0 additions & 3 deletions bayesflow/networks/flow_matching/integrators/__init__.py

This file was deleted.
