diff --git a/pastastore/base.py b/pastastore/base.py index 21a2f72..ac4b48f 100644 --- a/pastastore/base.py +++ b/pastastore/base.py @@ -15,6 +15,7 @@ import pastas as ps from numpy import isin from packaging.version import parse as parse_version +from pandas.testing import assert_series_equal from pastas.io.pas import PastasEncoder from tqdm.auto import tqdm @@ -30,7 +31,7 @@ class BaseConnector(ABC): Class holds base logic for dealing with time series and Pastas Models. Create your own Connector to a data source by writing a a class that inherits from this - BaseConnector. Your class has to override each abstractmethod and abstractproperty. + BaseConnector. Your class has to override each abstractmethod and property. """ _default_library_names = [ @@ -47,6 +48,10 @@ class BaseConnector(ABC): # True for pastas>=0.23.0 and False for pastas<=0.22.0 USE_PASTAS_VALIDATE_SERIES = False if PASTAS_LEQ_022 else True + # set series equality comparison settings (using assert_series_equal) + SERIES_EQUALITY_ABSOLUTE_TOLERANCE = 1e-10 + SERIES_EQUALITY_RELATIVE_TOLERANCE = 0.0 + def __repr__(self): """Representation string of the object.""" return ( @@ -670,22 +675,27 @@ def upsert_stress( metadata["kind"] = kind self._upsert_series("stresses", series, name, metadata=metadata) - def del_models(self, names: Union[list, str]) -> None: + def del_models(self, names: Union[list, str], verbose: bool = True) -> None: """Delete model(s) from the database. Parameters ---------- names : str or list of str name(s) of the model to delete + verbose : bool, optional + print information about deleted models, by default True """ - for n in self._parse_names(names, libname="models"): + names = self._parse_names(names, libname="models") + for n in names: mldict = self.get_models(n, return_dict=True) oname = mldict["oseries"]["name"] self._del_item("models", n) self._del_oseries_model_link(oname, n) self._clear_cache("_modelnames_cache") + if verbose: + print(f"Deleted {len(names)} model(s) from database.") - def del_model(self, names: Union[list, str]) -> None: + def del_model(self, names: Union[list, str], verbose: bool = True) -> None: """Delete model(s) from the database. Alias for del_models(). @@ -694,10 +704,14 @@ def del_model(self, names: Union[list, str]) -> None: ---------- names : str or list of str name(s) of the model to delete + verbose : bool, optional + print information about deleted models, by default True """ - self.del_models(names=names) + self.del_models(names=names, verbose=verbose) - def del_oseries(self, names: Union[list, str], remove_models: bool = False): + def del_oseries( + self, names: Union[list, str], remove_models: bool = False, verbose: bool = True + ): """Delete oseries from the database. Parameters @@ -706,29 +720,38 @@ def del_oseries(self, names: Union[list, str], remove_models: bool = False): name(s) of the oseries to delete remove_models : bool, optional also delete models for deleted oseries, default is False + verbose : bool, optional + print information about deleted oseries, by default True """ names = self._parse_names(names, libname="oseries") for n in names: self._del_item("oseries", n) self._clear_cache("oseries") + if verbose: + print(f"Deleted {len(names)} oseries from database.") # remove associated models from database if remove_models: modelnames = list( chain.from_iterable([self.oseries_models.get(n, []) for n in names]) ) - self.del_models(modelnames) + self.del_models(modelnames, verbose=verbose) - def del_stress(self, names: Union[list, str]): + def del_stress(self, names: Union[list, str], verbose: bool = True): """Delete stress from the database. Parameters ---------- names : str or list of str name(s) of the stress to delete + verbose : bool, optional + print information about deleted stresses, by default True """ - for n in self._parse_names(names, libname="stresses"): + names = self._parse_names(names, libname="stresses") + for n in names: self._del_item("stresses", n) self._clear_cache("stresses") + if verbose: + print(f"Deleted {len(names)} stress(es) from database.") def _get_series( self, @@ -1665,11 +1688,18 @@ def _check_oseries_in_store(self, ml: Union[ps.Model, dict]): so = ml.oseries.series_original else: so = ml.oseries._series_original - if not so.dropna().equals(s_org): + try: + assert_series_equal( + so.dropna(), + s_org, + atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE, + rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE, + ) + except AssertionError as e: raise ValueError( f"Cannot add model because model oseries '{name}'" - " is different from stored oseries!" - ) + " is different from stored oseries! See stacktrace for differences." + ) from e def _check_stresses_in_store(self, ml: Union[ps.Model, dict]): """Check if stresses time series are contained in PastaStore (internal method). @@ -1699,11 +1729,19 @@ def _check_stresses_in_store(self, ml: Union[ps.Model, dict]): so = s.series_original else: so = s._series_original - if not so.equals(s_org): + try: + assert_series_equal( + so, + s_org, + atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE, + rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE, + ) + except AssertionError as e: raise ValueError( f"Cannot add model because model stress " - f"'{s.name}' is different from stored stress!" - ) + f"'{s.name}' is different from stored stress! " + "See stacktrace for differences." + ) from e elif isinstance(ml, dict): for sm in ml["stressmodels"].values(): classkey = "stressmodel" if PASTAS_LEQ_022 else "class"