From c4eb3f1792a7e3b46d86205ff47271e0181096c1 Mon Sep 17 00:00:00 2001 From: Enio Hayashi Date: Wed, 22 Jan 2025 17:31:27 -0300 Subject: [PATCH 1/2] Implement equality operator for UQ classes, --- src/alfasim_sdk/result_reader/reader.py | 42 +++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/alfasim_sdk/result_reader/reader.py b/src/alfasim_sdk/result_reader/reader.py index 1a3c1492..f62a79cc 100644 --- a/src/alfasim_sdk/result_reader/reader.py +++ b/src/alfasim_sdk/result_reader/reader.py @@ -393,6 +393,23 @@ def get_sensitivity_curve( domain = Array(self.timeset, "s") return Curve(image=image, domain=domain) + def __eq__(self, other): + if other.__class__ is not self.__class__: + return False + + def compare_coefficients(dict_a, dict_b): + if dict_a.keys() != dict_b.keys(): + return False + for key in dict_a.keys(): + if (dict_a[key] != dict_b[key]).any(): + return False + return True + + return ( + (self.timeset == other.timeset).all() and + compare_coefficients(self.coefficients, other.coefficients) and + self.metadata == other.metadata + ) @define(frozen=True) class _BaseHistoryMatchingResults: @@ -443,6 +460,23 @@ def from_directory(cls, result_dir: Path) -> Self | None: metadata=metadata, ) + def __eq__(self, other): + if other.__class__ is not self.__class__: + return False + + def compare_probabilistic_distributions(dict_a, dict_b): + if dict_a.keys() != dict_b.keys(): + return False + for key in dict_a.keys(): + if (dict_a[key] != dict_b[key]).any(): + return False + return True + + return ( + self.historic_data_curves == other.historic_data_curves and + self.metadata == other.metadata and + compare_probabilistic_distributions(self.probabilistic_distributions, other.probabilistic_distributions) + ) def _read_curves_data( metadata: HistoryMatchingMetadata, @@ -481,3 +515,11 @@ def from_directory(cls, result_dir: Path) -> Self | None: results=read_uncertainty_propagation_results(metadata), metadata=metadata, ) + + def __eq__(self, other): + if other.__class__ is not self.__class__: + return False + return ( + (self.timeset == other.timeset).all() and + self.metadata == other.metadata + ) From 11fe0e79baf3478cc4bec3a6b28d17d616428776 Mon Sep 17 00:00:00 2001 From: Enio Hayashi Date: Thu, 23 Jan 2025 18:04:03 -0300 Subject: [PATCH 2/2] Fix instance check --- src/alfasim_sdk/result_reader/reader.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/alfasim_sdk/result_reader/reader.py b/src/alfasim_sdk/result_reader/reader.py index f62a79cc..7450b425 100644 --- a/src/alfasim_sdk/result_reader/reader.py +++ b/src/alfasim_sdk/result_reader/reader.py @@ -394,9 +394,8 @@ def get_sensitivity_curve( return Curve(image=image, domain=domain) def __eq__(self, other): - if other.__class__ is not self.__class__: + if not isinstance(other, GlobalSensitivityAnalysisResults): return False - def compare_coefficients(dict_a, dict_b): if dict_a.keys() != dict_b.keys(): return False @@ -461,9 +460,8 @@ def from_directory(cls, result_dir: Path) -> Self | None: ) def __eq__(self, other): - if other.__class__ is not self.__class__: + if not isinstance(other, HistoryMatchingProbabilisticResults): return False - def compare_probabilistic_distributions(dict_a, dict_b): if dict_a.keys() != dict_b.keys(): return False @@ -517,7 +515,7 @@ def from_directory(cls, result_dir: Path) -> Self | None: ) def __eq__(self, other): - if other.__class__ is not self.__class__: + if not isinstance(other, UncertaintyPropagationResults): return False return ( (self.timeset == other.timeset).all() and