Skip to content

Commit

Permalink
Adding option to save and load posteriors with pickle and json
Browse files Browse the repository at this point in the history
  • Loading branch information
marlinfiggins committed Jun 1, 2024
1 parent fb67e76 commit e7cf95c
Showing 1 changed file with 88 additions and 7 deletions.
95 changes: 88 additions & 7 deletions evofr/posterior/posterior_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,80 @@
import json
import pickle
from typing import Dict, List, Optional

from evofr.data.data_spec import DataSpec
from evofr.posterior.posterior_helpers import EvofrEncoder


def determine_method(filepath):
"""
Determines the serialization method based on the file extension.
Parameters:
filepath (str): The path of the file including its extension.
Returns:
str: The serialization method ("json" or "pickle").
Raises:
ValueError: If the file extension is not recognized.
"""
import os

_, ext = os.path.splitext(filepath)
if ext.lower() == ".json":
return "json"
elif ext.lower() == ".pkl":
return "pickle"
return None


def save_data(data, filename, method="json"):
"""
Save data to a file using either JSON or pickle based on the user's choice.
Parameters:
- data: The data to be serialized and saved.
- filename: The filename where the data will be saved.
- method: The serialization method ('json' or 'pickle').
Raises:
- ValueError: If the provided method is not supported.
"""
if method == "json":
with open(filename, "w") as file:
json.dump(data, file, cls=EvofrEncoder)
elif method == "pickle":
with open(filename, "wb") as file:
pickle.dump(data, file)
else:
raise ValueError("Unsupported serialization method. Use 'json' or 'pickle'.")


def load_data(filename, method="json"):
"""
Load data from a file using either JSON or pickle based on the user's choice.
Parameters:
- filename: The filename from which the data will be loaded.
- method: The serialization method ('json' or 'pickle').
Returns:
- The data loaded from the file.
Raises:
- ValueError: If the provided method is not supported.
"""
if method == "json":
with open(filename, "r") as file:
return json.load(file)
elif method == "pickle":
with open(filename, "rb") as file:
return pickle.load(file)
else:
raise ValueError("Unsupported deserialization method. Use 'json' or 'pickle'.")


class PosteriorHandler:
def __init__(
self,
Expand All @@ -30,16 +100,27 @@ def __init__(
self.data = data
self.name = name

def save_posterior(self, filepath: str):
def save_posterior(self, filepath: str, method=None):
if method is None:
method = determine_method(filepath)
assert (
method is not None
), """Serialization method could not be determined from `filepath`.
Please define explicitly or use compatiable file extension e.g. .json or .pkl)"""

"""Save posterior samples at a given filepath."""
if self.samples is not None:
with open(filepath, "w") as file:
json.dump(self.samples, file, cls=EvofrEncoder)
save_data(self.samples, filepath, method=method)

def load_posterior(self, filepath: str):
def load_posterior(self, filepath: str, method=None):
"""Load posterior samples from a given filepath."""
with open(filepath, "w") as file:
self.samples = json.load(file)
if method is None:
method = determine_method(filepath)
assert (
method is not None
), """Serialization method could not be determined from `filepath`.
Please define explicitly or use compatiable file extension e.g. .json or .pkl)"""
self.samples = load_data(filepath, method=method)
return self

def unpack_posterior(self):
"""Return samples and dataspec from PosteriorHandler."""
Expand Down

0 comments on commit e7cf95c

Please sign in to comment.