diff --git a/evofr/posterior/posterior_handler.py b/evofr/posterior/posterior_handler.py index 6a7bb7b..b55779d 100644 --- a/evofr/posterior/posterior_handler.py +++ b/evofr/posterior/posterior_handler.py @@ -1,10 +1,80 @@ import json +import pickle from typing import Dict, List, Optional from evofr.data.data_spec import DataSpec from evofr.posterior.posterior_helpers import EvofrEncoder +def determine_method(filepath): + """ + Determines the serialization method based on the file extension. + + Parameters: + filepath (str): The path of the file including its extension. + + Returns: + str: The serialization method ("json" or "pickle"). + + Raises: + ValueError: If the file extension is not recognized. + """ + import os + + _, ext = os.path.splitext(filepath) + if ext.lower() == ".json": + return "json" + elif ext.lower() == ".pkl": + return "pickle" + return None + + +def save_data(data, filename, method="json"): + """ + Save data to a file using either JSON or pickle based on the user's choice. + + Parameters: + - data: The data to be serialized and saved. + - filename: The filename where the data will be saved. + - method: The serialization method ('json' or 'pickle'). + + Raises: + - ValueError: If the provided method is not supported. + """ + if method == "json": + with open(filename, "w") as file: + json.dump(data, file, cls=EvofrEncoder) + elif method == "pickle": + with open(filename, "wb") as file: + pickle.dump(data, file) + else: + raise ValueError("Unsupported serialization method. Use 'json' or 'pickle'.") + + +def load_data(filename, method="json"): + """ + Load data from a file using either JSON or pickle based on the user's choice. + + Parameters: + - filename: The filename from which the data will be loaded. + - method: The serialization method ('json' or 'pickle'). + + Returns: + - The data loaded from the file. + + Raises: + - ValueError: If the provided method is not supported. + """ + if method == "json": + with open(filename, "r") as file: + return json.load(file) + elif method == "pickle": + with open(filename, "rb") as file: + return pickle.load(file) + else: + raise ValueError("Unsupported deserialization method. Use 'json' or 'pickle'.") + + class PosteriorHandler: def __init__( self, @@ -30,16 +100,27 @@ def __init__( self.data = data self.name = name - def save_posterior(self, filepath: str): + def save_posterior(self, filepath: str, method=None): + if method is None: + method = determine_method(filepath) + assert ( + method is not None + ), """Serialization method could not be determined from `filepath`. + Please define explicitly or use compatiable file extension e.g. .json or .pkl)""" + """Save posterior samples at a given filepath.""" - if self.samples is not None: - with open(filepath, "w") as file: - json.dump(self.samples, file, cls=EvofrEncoder) + save_data(self.samples, filepath, method=method) - def load_posterior(self, filepath: str): + def load_posterior(self, filepath: str, method=None): """Load posterior samples from a given filepath.""" - with open(filepath, "w") as file: - self.samples = json.load(file) + if method is None: + method = determine_method(filepath) + assert ( + method is not None + ), """Serialization method could not be determined from `filepath`. + Please define explicitly or use compatiable file extension e.g. .json or .pkl)""" + self.samples = load_data(filepath, method=method) + return self def unpack_posterior(self): """Return samples and dataspec from PosteriorHandler."""