Some love to model comparison #315

Merged
3 changes: 2 additions & 1 deletion README.md
@@ -99,7 +99,8 @@ Check out some of our walk-through notebooks below. We are actively working on p
4. [SBML model using an external simulator](examples/From_ABC_to_BayesFlow.ipynb)
5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb)
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
7. More coming soon...
7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb)
8. More coming soon...

## Documentation \& Help

42 changes: 42 additions & 0 deletions bayesflow/approximators/model_comparison_approximator.py
@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence

import keras
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
@@ -198,3 +199,44 @@ def get_config(self):
}

return base_config | config

def predict(
self,
*,
conditions: dict[str, np.ndarray],
logits: bool = False,
**kwargs,
) -> np.ndarray:
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

output = self._predict(**conditions, **kwargs)

if not logits:
output = keras.ops.softmax(output)

output = keras.ops.convert_to_numpy(output)

return output

def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
if self.summary_network is None:
if summary_variables is not None:
raise ValueError("Cannot use summary variables without a summary network.")
else:
if summary_variables is None:
raise ValueError("Summary variables are required when a summary network is present")

summary_outputs = self.summary_network(
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
)

if classifier_conditions is None:
classifier_conditions = summary_outputs
else:
classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=1)

output = self.classifier_network(classifier_conditions)
output = self.logits_projector(output)

return output
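
For orientation, here is a minimal sketch of how the new `predict` method could be used once an approximator is trained. The `approximator` object, the `"observables"` adapter key, and all shapes below are illustrative assumptions, not part of this diff:

```python
import numpy as np

# `approximator` is assumed to be a trained ModelComparisonApproximator
# whose adapter maps "observables" to the classifier/summary conditions.
observables = np.random.normal(size=(128, 50, 1))  # (num_data_sets, num_obs, dim)

# Posterior model probabilities of shape (num_data_sets, num_models);
# each row sums to 1 because predict applies a softmax by default.
pmps = approximator.predict(conditions={"observables": observables})

# With logits=True the raw classifier outputs are returned instead, which is
# convenient for numerically stable log Bayes factors (under uniform model priors).
logits = approximator.predict(conditions={"observables": observables}, logits=True)
log_bf_01 = logits[:, 0] - logits[:, 1]  # log BF of model 0 vs. model 1
```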
4 changes: 2 additions & 2 deletions bayesflow/diagnostics/plots/mc_calibration.py
@@ -34,10 +34,10 @@ def mc_calibration(

Parameters
----------
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
pred_models : np.ndarray of shape (num_data_sets, num_models)
The predicted posterior model probabilities (PMPs) per data set.
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
model_names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
num_bins : int, optional, default: 10
31 changes: 15 additions & 16 deletions bayesflow/diagnostics/plots/mc_confusion_matrix.py
@@ -13,8 +13,8 @@


def mc_confusion_matrix(
true_models: dict[str, np.ndarray] | np.ndarray,
pred_models: dict[str, np.ndarray] | np.ndarray,
true_models: dict[str, np.ndarray] | np.ndarray,
model_names: Sequence[str] = None,
fig_size: tuple = (5, 5),
label_fontsize: int = 16,
@@ -23,18 +23,18 @@ def mc_confusion_matrix(
tick_fontsize: int = 12,
xtick_rotation: int = None,
ytick_rotation: int = None,
normalize: bool = True,
normalize: str = None,
cmap: matplotlib.colors.Colormap | str = None,
title: bool = True,
) -> plt.Figure:
"""Plots a confusion matrix for validating a neural network trained for Bayesian model comparison.

Parameters
----------
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
pred_models : np.ndarray of shape (num_data_sets, num_models)
The predicted posterior model probabilities (PMPs) per data set.
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
model_names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
fig_size : tuple or None, optional, default: (5, 5)
@@ -51,9 +51,11 @@
Rotation of x-axis tick labels (helps with long model names).
ytick_rotation: int, optional, default: None
Rotation of y-axis tick labels (helps with long model names).
normalize : bool, optional, default: True
A flag for normalization of the confusion matrix.
If True, each row of the confusion matrix is normalized to sum to 1.
normalize : {'true', 'pred', 'all'} or None, optional, default: None
Passed directly to sklearn.metrics.confusion_matrix.
Normalizes the confusion matrix over the true labels (rows), the
predicted labels (columns), or the whole population. If None, the
confusion matrix is not normalized.
cmap : matplotlib.colors.Colormap or str, optional, default: None
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
@@ -77,29 +79,26 @@
pred_models = ops.argmax(pred_models, axis=1)

# Compute confusion matrix
cm = confusion_matrix(true_models, pred_models)

# if normalize:
# # Sum along rows and keep dimensions for broadcasting
# cm_sum = ops.sum(cm, axis=1, keepdims=True)
#
# # Broadcast division for normalization
# cm_normalized = cm / cm_sum
cm = confusion_matrix(true_models, pred_models, normalize=normalize)

# Initialize figure
fig, ax = make_figure(1, 1, figsize=fig_size)
ax = ax[0]
im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)

cbar.ax.tick_params(labelsize=value_fontsize)

ax.set(xticks=ops.arange(cm.shape[1]), yticks=ops.arange(cm.shape[0]))
ax.set_xticks(range(cm.shape[0]))
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
if xtick_rotation:
plt.xticks(rotation=xtick_rotation, ha="right")

ax.set_yticks(range(cm.shape[1]))
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
if ytick_rotation:
plt.yticks(rotation=ytick_rotation)

ax.set_xlabel("Predicted model", fontsize=label_fontsize)
ax.set_ylabel("True model", fontsize=label_fontsize)

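To see the new `normalize` behavior end to end, here is a small sketch with simulated labels; the data, and the import path, are assumptions for illustration only:

```python
import numpy as np
from bayesflow.diagnostics.plots import mc_confusion_matrix  # assumed import path

rng = np.random.default_rng(seed=0)
num_data_sets, num_models = 500, 3

# Fake one-hot true model indices and PMPs that favor the true model.
true_models = np.eye(num_models)[rng.integers(num_models, size=num_data_sets)]
pred_models = true_models + rng.uniform(0.0, 0.8, size=true_models.shape)
pred_models /= pred_models.sum(axis=1, keepdims=True)

# normalize="true" turns each row into P(predicted | true), roughly the
# row-normalization the old boolean flag described; normalize=None plots raw counts.
fig = mc_confusion_matrix(pred_models, true_models, normalize="true")
```

Note that `pred_models` now comes first in the signature, consistent with the reordered docstring above.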
44 changes: 28 additions & 16 deletions bayesflow/simulators/model_comparison_simulator.py
@@ -2,12 +2,15 @@
import numpy as np

from bayesflow.types import Shape
from bayesflow.utils import tree_stack
from bayesflow.utils import tree_concatenate
from bayesflow.utils.decorators import allow_batch_size

from bayesflow.utils import numpy_utils as npu

from types import FunctionType

from .simulator import Simulator
from .lambda_simulator import LambdaSimulator


class ModelComparisonSimulator(Simulator):
@@ -18,10 +21,15 @@ def __init__(
simulators: Sequence[Simulator],
p: Sequence[float] = None,
logits: Sequence[float] = None,
use_mixed_batches: bool = False,
use_mixed_batches: bool = True,
shared_simulator: Simulator | FunctionType = None,
):
self.simulators = simulators

if isinstance(shared_simulator, FunctionType):
shared_simulator = LambdaSimulator(shared_simulator, is_batched=True)
self.shared_simulator = shared_simulator

match logits, p:
case (None, None):
logits = [0.0] * len(simulators)
@@ -43,30 +51,34 @@

@allow_batch_size
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
data = {}
if self.shared_simulator:
data |= self.shared_simulator.sample(batch_shape, **kwargs)

if not self.use_mixed_batches:
# draw one model index for the whole batch (faster)
model_index = np.random.choice(len(self.simulators), p=npu.softmax(self.logits))

simulator = self.simulators[model_index]
data = simulator.sample(batch_shape)
data = simulator.sample(batch_shape, **(kwargs | data))

model_indices = np.full(batch_shape, model_index, dtype="int32")
model_indices = npu.one_hot(model_indices, len(self.simulators))
else:
# draw a model index for each sample in the batch (slower)
model_indices = np.random.choice(len(self.simulators), p=npu.softmax(self.logits), size=batch_shape)

data = np.empty(batch_shape, dtype="object")

for index in np.ndindex(batch_shape):
simulator = self.simulators[int(model_indices[index])]
data[index] = simulator.sample(())
# generate data randomly from each model (slower)
model_counts = np.random.multinomial(n=batch_shape[0], pvals=npu.softmax(self.logits))

data = data.flatten().tolist()
data = tree_stack(data, axis=0, numpy=True)
sims = []
for n, simulator in zip(model_counts, self.simulators):
if n == 0:
continue
sim = simulator.sample(n, **(kwargs | data))
sims.append(sim)

# restore batch shape
data = {key: np.reshape(value, batch_shape + np.shape(value)[1:]) for key, value in data.items()}
sims = tree_concatenate(sims, numpy=True)
data |= sims

model_indices = npu.one_hot(model_indices, len(self.simulators))
model_indices = np.eye(len(self.simulators), dtype="int32")
model_indices = np.repeat(model_indices, model_counts, axis=0)

return data | {"model_indices": model_indices}
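
The mixed-batch branch above is easiest to see in isolation. Below is a self-contained sketch of just its bookkeeping, i.e. drawing per-model counts from a multinomial and expanding them into one-hot indicator rows; the numbers are illustrative:

```python
import numpy as np

rng = np.random.default_rng(seed=1)
batch_size, num_models = 8, 3
probs = np.full(num_models, 1.0 / num_models)  # softmax of all-zero logits

# How many data sets each model contributes to this batch.
model_counts = rng.multinomial(n=batch_size, pvals=probs)

# One indicator row per data set, grouped by model, matching the one-hot
# rows produced at the end of ModelComparisonSimulator.sample.
model_indices = np.repeat(np.eye(num_models, dtype="int32"), model_counts, axis=0)

print(model_counts)         # e.g. [3 2 3]
print(model_indices.shape)  # (8, 3)
```

One consequence of this design is that samples within a mixed batch are ordered by model index, since each simulator's outputs are concatenated in turn; anything that assumes shuffled batches should shuffle downstream.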