How to save a trained ContinuousApproximator model? #302

Closed
philippreiser opened this issue Feb 11, 2025 · 3 comments

philippreiser commented Feb 11, 2025

Hi!

Description

I am trying to save a trained ContinuousApproximator model on the dev branch using approximator.save("approximator.keras"), but I am unsure if this is the correct approach. Here's a minimal example based on the Linear Regression example notebook:

import os

if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "torch"

import bayesflow as bf
import keras
import numpy as np

def prior():
    # beta: regression coefficients (intercept, slope)
    beta = np.random.normal([2, 0], [3, 1])
    return dict(beta=beta)

def likelihood(beta):
    # x: predictor variable
    x = np.random.normal(0, 1, size=10)
    # y: response variable
    y = np.random.normal(beta[0] + beta[1] * x, size=10)
    return dict(y=y, x=x)

simulator = bf.simulators.make_simulator([prior, likelihood])
adapter = (
    bf.Adapter()
    .as_set(["x", "y"])
    .standardize()
    .concatenate(["beta"], into="inference_variables")
    .concatenate(["x", "y"], into="summary_variables")
)
inference_network = bf.networks.FlowMatching()
summary_network = bf.networks.DeepSet(depth=2)
approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=summary_network,
    adapter=adapter,
)
epochs = 1
num_batches = 1
batch_size = 1
optimizer = keras.optimizers.Adam(learning_rate=5e-4, clipnorm=1.0)
approximator.compile(optimizer=optimizer)
history = approximator.fit(
    epochs=epochs,
    num_batches=num_batches,
    batch_size=batch_size,
    simulator=simulator,
)
approximator.save("approximator.keras")

Error Message:

NotImplementedError                       Traceback (most recent call last)
Cell In[12], line 1
----> 1 approximator.save("approximator.keras")

File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/approximators/continuous_approximator.py:127, in ContinuousApproximator.get_config(self)
    124 def get_config(self):
    125     base_config = super().get_config()
    126     config = {
--> 127         "adapter": serialize(self.adapter),
    128         "inference_network": serialize(self.inference_network),
    129         "summary_network": serialize(self.summary_network),
    130     }
    132     return base_config | config

File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/adapters/adapter.py:55, in Adapter.get_config(self)
     54 def get_config(self) -> dict:
...
File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/adapters/transforms/elementwise_transform.py:20, in ElementwiseTransform.get_config(self)
     19 def get_config(self) -> dict:
---> 20     raise NotImplementedError
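
The last frame of the traceback shows a base-class get_config that raises NotImplementedError, which means at least one transform in the adapter does not override it. A minimal sketch of how one might find the offending transform follows. The classes ToyTransform, ToyStandardize, and ToyAsSet and the helper find_unserializable are hypothetical stand-ins, not BayesFlow code, and the sketch assumes you already have the transforms in a plain list.

```python
class ToyTransform:
    """Stand-in base class: serialization is left to subclasses."""

    def get_config(self) -> dict:
        raise NotImplementedError


class ToyStandardize(ToyTransform):
    """Hypothetical transform that implements get_config."""

    def get_config(self) -> dict:
        return {}


class ToyAsSet(ToyTransform):
    """Hypothetical transform that does not override get_config."""


def find_unserializable(transforms):
    """Return class names of transforms whose get_config raises NotImplementedError."""
    missing = []
    for transform in transforms:
        try:
            transform.get_config()
        except NotImplementedError:
            missing.append(type(transform).__name__)
    return missing


pipeline = [ToyAsSet(), ToyStandardize()]
print(find_unserializable(pipeline))  # → ['ToyAsSet']
```

Checking each transform on its own narrows the failure down to a single class instead of the whole save call.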

philippreiser changed the title from "How to save a trained ContinuousApproximator model in BayesFlow?" to "How to save a trained ContinuousApproximator model?" on Feb 11, 2025

vpratz (Collaborator) commented Feb 11, 2025

Thanks for reporting this. The adapter is not fully serializable yet, but I think we should be able to finish that soon. I will take a look and let you know; if not, we should be able to provide a temporary workaround.

vpratz (Collaborator) commented Feb 11, 2025

There was some code missing to enable serialization of the AsSet transform in the adapter. This should be fixed once #306 is merged. Thanks again for reporting the issue. Please let us know if you encounter any other bugs or difficulties.
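
For readers unfamiliar with what "enabling serialization" involves: Keras saving relies on each object exposing get_config and being rebuildable through from_config. The sketch below shows that round trip in plain Python. ToyAsSet and its keys argument are hypothetical illustrations and not the code from #306.

```python
import json


class ToyAsSet:
    """Hypothetical transform that records which keys are treated as sets."""

    def __init__(self, keys):
        self.keys = list(keys)

    def get_config(self) -> dict:
        # Return only plain, JSON-friendly values needed to rebuild the object.
        return {"keys": self.keys}

    @classmethod
    def from_config(cls, config: dict) -> "ToyAsSet":
        # Rebuild the object from the dictionary produced by get_config.
        return cls(**config)


original = ToyAsSet(["x", "y"])
payload = json.dumps(original.get_config())  # what would be written to disk
restored = ToyAsSet.from_config(json.loads(payload))
print(restored.keys)  # → ['x', 'y']
```

If get_config is missing on any nested object, the save call fails as in the traceback above, because the parent's config cannot be assembled.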

vpratz closed this as completed on Feb 11, 2025

philippreiser (Author) commented Feb 11, 2025

Thank you for this quick fix!
