From ee11a7b90962295ee4e8b2ac10621b8b1c3f54fb Mon Sep 17 00:00:00 2001 From: SamTov Date: Thu, 25 May 2023 22:57:53 +0200 Subject: [PATCH 1/2] Merge branch 'main' into SamTov_analysis_signature --- znnl/distance_metrics/__init__.py | 2 + znnl/distance_metrics/angular_distance.py | 11 +++ znnl/distance_metrics/cosine_distance.py | 11 +++ .../cross_entropy_distance.py | 50 +++++++++++ znnl/distance_metrics/distance_metric.py | 27 +++++- .../distance_metrics/hyper_sphere_distance.py | 11 +++ znnl/distance_metrics/l_p_norm.py | 11 +++ znnl/distance_metrics/mahalanobis_distance.py | 11 +++ znnl/distance_metrics/order_n_difference.py | 11 +++ znnl/loss_functions/__init__.py | 4 +- .../absolute_angle_difference.py | 15 +++- znnl/loss_functions/cosine_distance.py | 15 +++- znnl/loss_functions/cross_entropy_loss.py | 37 ++++----- znnl/loss_functions/l_p_norm.py | 13 ++- .../{simple_loss.py => loss.py} | 29 ++++++- znnl/loss_functions/mahalanobis.py | 13 ++- znnl/loss_functions/mean_power_error.py | 15 +++- znnl/observables/__init__.py | 32 +++++++ znnl/observables/observable.py | 83 +++++++++++++++++++ 19 files changed, 363 insertions(+), 38 deletions(-) create mode 100644 znnl/distance_metrics/cross_entropy_distance.py rename znnl/loss_functions/{simple_loss.py => loss.py} (73%) create mode 100644 znnl/observables/__init__.py create mode 100644 znnl/observables/observable.py diff --git a/znnl/distance_metrics/__init__.py b/znnl/distance_metrics/__init__.py index 25630d4..04010f0 100644 --- a/znnl/distance_metrics/__init__.py +++ b/znnl/distance_metrics/__init__.py @@ -31,6 +31,7 @@ from znnl.distance_metrics.l_p_norm import LPNorm from znnl.distance_metrics.mahalanobis_distance import MahalanobisDistance from znnl.distance_metrics.order_n_difference import OrderNDifference +from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance __all__ = [ DistanceMetric.__name__, @@ -40,4 +41,5 @@ OrderNDifference.__name__, MahalanobisDistance.__name__, HyperSphere.__name__, + CrossEntropyDistance.__name__ ] diff --git a/znnl/distance_metrics/angular_distance.py b/znnl/distance_metrics/angular_distance.py index cc80d94..40cceb7 100644 --- a/znnl/distance_metrics/angular_distance.py +++ b/znnl/distance_metrics/angular_distance.py @@ -49,6 +49,17 @@ def __init__(self, points: int = None): self.normalization = points / np.pi else: raise ValueError("Invalid points input.") + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"angular_distance" def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): """ diff --git a/znnl/distance_metrics/cosine_distance.py b/znnl/distance_metrics/cosine_distance.py index 103f68b..ac13c41 100644 --- a/znnl/distance_metrics/cosine_distance.py +++ b/znnl/distance_metrics/cosine_distance.py @@ -68,3 +68,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): ) return 1 - abs(np.divide(numerator, denominator)) + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"cosine_distance" diff --git a/znnl/distance_metrics/cross_entropy_distance.py b/znnl/distance_metrics/cross_entropy_distance.py new file mode 100644 index 0000000..b9bbc69 --- /dev/null +++ b/znnl/distance_metrics/cross_entropy_distance.py @@ -0,0 +1,50 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" +import optax + +from znnl.distance_metrics.distance_metric import DistanceMetric + + +class CrossEntropyDistance(DistanceMetric): + """ + Class for the cross entropy distance + """ + + def __call__(self, prediction, target): + """ + + Parameters + ---------- + prediction (batch_size, n_classes) + target + + Returns + ------- + Softmax cross entropy of the batch. + + """ + return optax.softmax_cross_entropy(logits=prediction, labels=target) diff --git a/znnl/distance_metrics/distance_metric.py b/znnl/distance_metrics/distance_metric.py index 8c68e78..d98a60e 100644 --- a/znnl/distance_metrics/distance_metric.py +++ b/znnl/distance_metrics/distance_metric.py @@ -26,12 +26,37 @@ """ import jax.numpy as np +from znnl.observables.observable import Observable -class DistanceMetric: + +class DistanceMetric(Observable): """ Parent class for a ZnRND distance metric. """ + def __name__(self) -> str: + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return "distance_metric" + + def __signature__(self) -> tuple: + """ + Signature of the class. + + Returns + ------- + signature : tuple + The signature of the class. + For the distance metric, it is (1,). + """ + return (1,) + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): """ Call the distance metric. diff --git a/znnl/distance_metrics/hyper_sphere_distance.py b/znnl/distance_metrics/hyper_sphere_distance.py index 1ea5100..4724fb6 100644 --- a/znnl/distance_metrics/hyper_sphere_distance.py +++ b/znnl/distance_metrics/hyper_sphere_distance.py @@ -73,3 +73,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): return LPNorm(order=self.order)(point_1, point_2) * CosineDistance()( point_1, point_2 ) + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"hyper_sphere_distance_{self.order}" diff --git a/znnl/distance_metrics/l_p_norm.py b/znnl/distance_metrics/l_p_norm.py index f3492e3..91947a6 100644 --- a/znnl/distance_metrics/l_p_norm.py +++ b/znnl/distance_metrics/l_p_norm.py @@ -69,3 +69,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): Array of distances for each point. """ return np.linalg.norm(point_1 - point_2, axis=1, ord=self.order) + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"lp_norm_{self.order}" diff --git a/znnl/distance_metrics/mahalanobis_distance.py b/znnl/distance_metrics/mahalanobis_distance.py index 832f377..af7fc0f 100644 --- a/znnl/distance_metrics/mahalanobis_distance.py +++ b/znnl/distance_metrics/mahalanobis_distance.py @@ -66,3 +66,14 @@ def __call__(self, point_1: np.array, point_2: np.array, **kwargs) -> np.array: distances.append(distance) return distances + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"mahalanobis_distance" diff --git a/znnl/distance_metrics/order_n_difference.py b/znnl/distance_metrics/order_n_difference.py index 7963d49..feff1b0 100644 --- a/znnl/distance_metrics/order_n_difference.py +++ b/znnl/distance_metrics/order_n_difference.py @@ -79,3 +79,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): return np.sum(np.power(diff, self.order), axis=1) else: raise ValueError(f"Invalid reduction operation: {self.reduce_operation}") + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"order_{self.order}_difference_{self.reduce_operation}" diff --git a/znnl/loss_functions/__init__.py b/znnl/loss_functions/__init__.py index 8abcbd1..230d8e5 100644 --- a/znnl/loss_functions/__init__.py +++ b/znnl/loss_functions/__init__.py @@ -30,7 +30,7 @@ from znnl.loss_functions.l_p_norm import LPNormLoss from znnl.loss_functions.mahalanobis import MahalanobisLoss from znnl.loss_functions.mean_power_error import MeanPowerLoss -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import Loss __all__ = [ AngleDistanceLoss.__name__, @@ -38,6 +38,6 @@ LPNormLoss.__name__, MahalanobisLoss.__name__, MeanPowerLoss.__name__, - SimpleLoss.__name__, + Loss.__name__, CrossEntropyLoss.__name__, ] diff --git a/znnl/loss_functions/absolute_angle_difference.py b/znnl/loss_functions/absolute_angle_difference.py index 4c521f3..e195476 100644 --- a/znnl/loss_functions/absolute_angle_difference.py +++ b/znnl/loss_functions/absolute_angle_difference.py @@ -25,10 +25,10 @@ ------- """ from znnl.distance_metrics.angular_distance import AngularDistance -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import Loss -class AngleDistanceLoss(SimpleLoss): +class AngleDistanceLoss(Loss): """ Class for the mean power loss """ @@ -39,3 +39,14 @@ def __init__(self): """ super(AngleDistanceLoss, self).__init__() self.metric = AngularDistance() + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"angle_distance_loss" diff --git a/znnl/loss_functions/cosine_distance.py b/znnl/loss_functions/cosine_distance.py index 0768f34..ab443fc 100644 --- a/znnl/loss_functions/cosine_distance.py +++ b/znnl/loss_functions/cosine_distance.py @@ -25,10 +25,10 @@ ------- """ from znnl.distance_metrics.cosine_distance import CosineDistance -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import Loss -class CosineDistanceLoss(SimpleLoss): +class CosineDistanceLoss(Loss): """ Class for the mean power loss """ @@ -39,3 +39,14 @@ def __init__(self): """ super(CosineDistanceLoss, self).__init__() self.metric = CosineDistance() + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"cosine_distance_loss" diff --git a/znnl/loss_functions/cross_entropy_loss.py b/znnl/loss_functions/cross_entropy_loss.py index 67a3250..2c92ad0 100644 --- a/znnl/loss_functions/cross_entropy_loss.py +++ b/znnl/loss_functions/cross_entropy_loss.py @@ -26,31 +26,11 @@ """ import optax -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import Loss +from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance -class CrossEntropyDistance: - """ - Class for the cross entropy distance - """ - - def __call__(self, prediction, target): - """ - - Parameters - ---------- - prediction (batch_size, n_classes) - target - - Returns - ------- - Softmax cross entropy of the batch. - - """ - return optax.softmax_cross_entropy(logits=prediction, labels=target) - - -class CrossEntropyLoss(SimpleLoss): +class CrossEntropyLoss(Loss): """ Class for the cross entropy loss """ @@ -61,3 +41,14 @@ def __init__(self): """ super(CrossEntropyLoss, self).__init__() self.metric = CrossEntropyDistance() + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"cross_entropy_loss" diff --git a/znnl/loss_functions/l_p_norm.py b/znnl/loss_functions/l_p_norm.py index 97066b8..628a2e1 100644 --- a/znnl/loss_functions/l_p_norm.py +++ b/znnl/loss_functions/l_p_norm.py @@ -25,7 +25,7 @@ ------- """ from znnl.distance_metrics.l_p_norm import LPNorm -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import SimpleLoss class LPNormLoss(SimpleLoss): @@ -44,3 +44,14 @@ def __init__(self, order: float): """ super(LPNormLoss, self).__init__() self.metric = LPNorm(order=order) + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"lp_norm_loss_{self.order}" diff --git a/znnl/loss_functions/simple_loss.py b/znnl/loss_functions/loss.py similarity index 73% rename from znnl/loss_functions/simple_loss.py rename to znnl/loss_functions/loss.py index 9894b3c..e5dc441 100644 --- a/znnl/loss_functions/simple_loss.py +++ b/znnl/loss_functions/loss.py @@ -24,16 +24,15 @@ Summary ------- """ -from abc import ABC - import jax.numpy as np from znnl.distance_metrics.distance_metric import DistanceMetric +from znnl.observables.observable import Observable -class SimpleLoss(ABC): +class Loss(Observable): """ - Class for the simple loss. + Parent class for the loss. Attributes ---------- @@ -47,6 +46,28 @@ def __init__(self): super().__init__() self.metric: DistanceMetric = None + def __name__(self) -> str: + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return "loss_parent" + + def __signature__(self) -> tuple: + """ + Signature of the class. + + Returns + ------- + signature : tuple + For loss this should always be (1,) + """ + return (1,) + def __call__(self, point_1: np.array, point_2: np.array) -> float: """ Summation over the tensor of the respective similarity measurement diff --git a/znnl/loss_functions/mahalanobis.py b/znnl/loss_functions/mahalanobis.py index 1013f95..281d266 100644 --- a/znnl/loss_functions/mahalanobis.py +++ b/znnl/loss_functions/mahalanobis.py @@ -25,7 +25,7 @@ ------- """ import znnl.distance_metrics.mahalanobis_distance as mahalanobis -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import SimpleLoss class MahalanobisLoss(SimpleLoss): @@ -39,3 +39,14 @@ def __init__(self): """ super(MahalanobisLoss, self).__init__() self.metric = mahalanobis.MahalanobisDistance() + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"mahalanobis_loss" diff --git a/znnl/loss_functions/mean_power_error.py b/znnl/loss_functions/mean_power_error.py index e6ae162..4b9e6c6 100644 --- a/znnl/loss_functions/mean_power_error.py +++ b/znnl/loss_functions/mean_power_error.py @@ -25,10 +25,10 @@ ------- """ from znnl.distance_metrics.order_n_difference import OrderNDifference -from znnl.loss_functions.simple_loss import SimpleLoss +from znnl.loss_functions.loss import Loss -class MeanPowerLoss(SimpleLoss): +class MeanPowerLoss(Loss): """ Class for the mean power loss """ @@ -44,3 +44,14 @@ def __init__(self, order: float): """ super(MeanPowerLoss, self).__init__() self.metric = OrderNDifference(order=order) + + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return f"mean_power_loss_{self.order}" diff --git a/znnl/observables/__init__.py b/znnl/observables/__init__.py new file mode 100644 index 0000000..7ac9b0d --- /dev/null +++ b/znnl/observables/__init__.py @@ -0,0 +1,32 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for ZnNL observables. +""" +from znnl.observables.observable import Observable + +__all__ = [ + Observable.__name__ +] diff --git a/znnl/observables/observable.py b/znnl/observables/observable.py new file mode 100644 index 0000000..8dd5bd6 --- /dev/null +++ b/znnl/observables/observable.py @@ -0,0 +1,83 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Observable parent class module. +""" +from typing import Union + +import jax.numpy as np + + +class Observable: + """ + Parent class for all observables. + """ + + @classmethod + def __name__(self) -> str: + """ + Name of the observable. + + Returns + ------- + name : str + The name of the observable. + """ + raise NotImplementedError("Implemented in child class.") + + @classmethod + def __signature__(self, data_set: dict) -> tuple: + """ + Name of the observable. + + Parameters + ---------- + data_set : dict + The data set to be used for the observable. + + Returns + ------- + signature : tuple + The signature of the observable. + """ + raise NotImplementedError("Implemented in child class.") + + @classmethod + def __call__(self, data_set: dict) -> Union[str, np.ndarray, float]: + """ + Compute the observable. + + Parameters + ---------- + data_set : dict + The data set to be used for the observable. + + Returns + ------- + value : Union[str, np.ndarray, float] + The value of the observable. + + """ + raise NotImplementedError("Implemented in child class.") From f2861780fc63eecc59d4cf90204ceaf7b96f8a81 Mon Sep 17 00:00:00 2001 From: SamTov Date: Fri, 26 May 2023 11:35:10 +0200 Subject: [PATCH 2/2] update pre-commit and fix formatting --- .pre-commit-config.yaml | 6 +++--- znnl/distance_metrics/__init__.py | 4 ++-- znnl/distance_metrics/angular_distance.py | 4 ++-- znnl/distance_metrics/cosine_distance.py | 4 ++-- .../distance_metrics/cross_entropy_distance.py | 11 +++++++++++ znnl/distance_metrics/distance_metric.py | 2 +- znnl/distance_metrics/hyper_sphere_distance.py | 2 +- znnl/distance_metrics/l_p_norm.py | 2 +- znnl/distance_metrics/mahalanobis_distance.py | 4 ++-- znnl/distance_metrics/order_n_difference.py | 4 ++-- znnl/loss_functions/__init__.py | 2 +- .../absolute_angle_difference.py | 18 +++++++++--------- znnl/loss_functions/cosine_distance.py | 2 +- znnl/loss_functions/cross_entropy_loss.py | 6 ++---- znnl/loss_functions/loss.py | 2 +- znnl/loss_functions/mahalanobis.py | 2 +- znnl/loss_functions/mean_power_error.py | 2 +- znnl/observables/__init__.py | 4 +--- znnl/observables/observable.py | 6 +++--- .../partitioned_training.py | 6 +++--- 20 files changed, 50 insertions(+), 43 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d3f4c24..29aa27f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,17 +4,17 @@ fail_fast: true repos: - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 23.3.0 hooks: - id: black - repo: https://github.com/timothycrosley/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 additional_dependencies: [flake8-isort] diff --git a/znnl/distance_metrics/__init__.py b/znnl/distance_metrics/__init__.py index 04010f0..5e85a97 100644 --- a/znnl/distance_metrics/__init__.py +++ b/znnl/distance_metrics/__init__.py @@ -26,12 +26,12 @@ """ from znnl.distance_metrics.angular_distance import AngularDistance from znnl.distance_metrics.cosine_distance import CosineDistance +from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance from znnl.distance_metrics.distance_metric import DistanceMetric from znnl.distance_metrics.hyper_sphere_distance import HyperSphere from znnl.distance_metrics.l_p_norm import LPNorm from znnl.distance_metrics.mahalanobis_distance import MahalanobisDistance from znnl.distance_metrics.order_n_difference import OrderNDifference -from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance __all__ = [ DistanceMetric.__name__, @@ -41,5 +41,5 @@ OrderNDifference.__name__, MahalanobisDistance.__name__, HyperSphere.__name__, - CrossEntropyDistance.__name__ + CrossEntropyDistance.__name__, ] diff --git a/znnl/distance_metrics/angular_distance.py b/znnl/distance_metrics/angular_distance.py index 40cceb7..cb3e258 100644 --- a/znnl/distance_metrics/angular_distance.py +++ b/znnl/distance_metrics/angular_distance.py @@ -49,7 +49,7 @@ def __init__(self, points: int = None): self.normalization = points / np.pi else: raise ValueError("Invalid points input.") - + def __name__(self): """ Name of the class. @@ -59,7 +59,7 @@ def __name__(self): name : str The name of the class. """ - return f"angular_distance" + return "angular_distance" def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): """ diff --git a/znnl/distance_metrics/cosine_distance.py b/znnl/distance_metrics/cosine_distance.py index ac13c41..2573478 100644 --- a/znnl/distance_metrics/cosine_distance.py +++ b/znnl/distance_metrics/cosine_distance.py @@ -68,7 +68,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): ) return 1 - abs(np.divide(numerator, denominator)) - + def __name__(self): """ Name of the class. @@ -78,4 +78,4 @@ def __name__(self): name : str The name of the class. """ - return f"cosine_distance" + return "cosine_distance" diff --git a/znnl/distance_metrics/cross_entropy_distance.py b/znnl/distance_metrics/cross_entropy_distance.py index b9bbc69..3724613 100644 --- a/znnl/distance_metrics/cross_entropy_distance.py +++ b/znnl/distance_metrics/cross_entropy_distance.py @@ -34,6 +34,17 @@ class CrossEntropyDistance(DistanceMetric): Class for the cross entropy distance """ + def __name__(self): + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return "cross_entropy_distance" + def __call__(self, prediction, target): """ diff --git a/znnl/distance_metrics/distance_metric.py b/znnl/distance_metrics/distance_metric.py index d98a60e..d273e04 100644 --- a/znnl/distance_metrics/distance_metric.py +++ b/znnl/distance_metrics/distance_metric.py @@ -44,7 +44,7 @@ def __name__(self) -> str: The name of the class. """ return "distance_metric" - + def __signature__(self) -> tuple: """ Signature of the class. diff --git a/znnl/distance_metrics/hyper_sphere_distance.py b/znnl/distance_metrics/hyper_sphere_distance.py index 4724fb6..894f406 100644 --- a/znnl/distance_metrics/hyper_sphere_distance.py +++ b/znnl/distance_metrics/hyper_sphere_distance.py @@ -73,7 +73,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): return LPNorm(order=self.order)(point_1, point_2) * CosineDistance()( point_1, point_2 ) - + def __name__(self): """ Name of the class. diff --git a/znnl/distance_metrics/l_p_norm.py b/znnl/distance_metrics/l_p_norm.py index 91947a6..787a58e 100644 --- a/znnl/distance_metrics/l_p_norm.py +++ b/znnl/distance_metrics/l_p_norm.py @@ -69,7 +69,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): Array of distances for each point. """ return np.linalg.norm(point_1 - point_2, axis=1, ord=self.order) - + def __name__(self): """ Name of the class. diff --git a/znnl/distance_metrics/mahalanobis_distance.py b/znnl/distance_metrics/mahalanobis_distance.py index af7fc0f..3aaae8a 100644 --- a/znnl/distance_metrics/mahalanobis_distance.py +++ b/znnl/distance_metrics/mahalanobis_distance.py @@ -66,7 +66,7 @@ def __call__(self, point_1: np.array, point_2: np.array, **kwargs) -> np.array: distances.append(distance) return distances - + def __name__(self): """ Name of the class. @@ -76,4 +76,4 @@ def __name__(self): name : str The name of the class. """ - return f"mahalanobis_distance" + return "mahalanobis_distance" diff --git a/znnl/distance_metrics/order_n_difference.py b/znnl/distance_metrics/order_n_difference.py index feff1b0..6fee49f 100644 --- a/znnl/distance_metrics/order_n_difference.py +++ b/znnl/distance_metrics/order_n_difference.py @@ -79,7 +79,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): return np.sum(np.power(diff, self.order), axis=1) else: raise ValueError(f"Invalid reduction operation: {self.reduce_operation}") - + def __name__(self): """ Name of the class. @@ -89,4 +89,4 @@ def __name__(self): name : str The name of the class. """ - return f"order_{self.order}_difference_{self.reduce_operation}" + return "order_{self.order}_difference_{self.reduce_operation}" diff --git a/znnl/loss_functions/__init__.py b/znnl/loss_functions/__init__.py index 230d8e5..0ab9abb 100644 --- a/znnl/loss_functions/__init__.py +++ b/znnl/loss_functions/__init__.py @@ -28,9 +28,9 @@ from znnl.loss_functions.cosine_distance import CosineDistanceLoss from znnl.loss_functions.cross_entropy_loss import CrossEntropyLoss from znnl.loss_functions.l_p_norm import LPNormLoss +from znnl.loss_functions.loss import Loss from znnl.loss_functions.mahalanobis import MahalanobisLoss from znnl.loss_functions.mean_power_error import MeanPowerLoss -from znnl.loss_functions.loss import Loss __all__ = [ AngleDistanceLoss.__name__, diff --git a/znnl/loss_functions/absolute_angle_difference.py b/znnl/loss_functions/absolute_angle_difference.py index e195476..2db6021 100644 --- a/znnl/loss_functions/absolute_angle_difference.py +++ b/znnl/loss_functions/absolute_angle_difference.py @@ -41,12 +41,12 @@ def __init__(self): self.metric = AngularDistance() def __name__(self): - """ - Name of the class. - - Returns - ------- - name : str - The name of the class. - """ - return f"angle_distance_loss" + """ + Name of the class. + + Returns + ------- + name : str + The name of the class. + """ + return "angle_distance_loss" diff --git a/znnl/loss_functions/cosine_distance.py b/znnl/loss_functions/cosine_distance.py index ab443fc..a661f0b 100644 --- a/znnl/loss_functions/cosine_distance.py +++ b/znnl/loss_functions/cosine_distance.py @@ -49,4 +49,4 @@ def __name__(self): name : str The name of the class. """ - return f"cosine_distance_loss" + return "cosine_distance_loss" diff --git a/znnl/loss_functions/cross_entropy_loss.py b/znnl/loss_functions/cross_entropy_loss.py index 2c92ad0..67df0e9 100644 --- a/znnl/loss_functions/cross_entropy_loss.py +++ b/znnl/loss_functions/cross_entropy_loss.py @@ -24,10 +24,8 @@ Summary ------- """ -import optax - -from znnl.loss_functions.loss import Loss from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance +from znnl.loss_functions.loss import Loss class CrossEntropyLoss(Loss): @@ -51,4 +49,4 @@ def __name__(self): name : str The name of the class. """ - return f"cross_entropy_loss" + return "cross_entropy_loss" diff --git a/znnl/loss_functions/loss.py b/znnl/loss_functions/loss.py index e5dc441..40c2781 100644 --- a/znnl/loss_functions/loss.py +++ b/znnl/loss_functions/loss.py @@ -56,7 +56,7 @@ def __name__(self) -> str: The name of the class. """ return "loss_parent" - + def __signature__(self) -> tuple: """ Signature of the class. diff --git a/znnl/loss_functions/mahalanobis.py b/znnl/loss_functions/mahalanobis.py index 281d266..3723668 100644 --- a/znnl/loss_functions/mahalanobis.py +++ b/znnl/loss_functions/mahalanobis.py @@ -49,4 +49,4 @@ def __name__(self): name : str The name of the class. """ - return f"mahalanobis_loss" + return "mahalanobis_loss" diff --git a/znnl/loss_functions/mean_power_error.py b/znnl/loss_functions/mean_power_error.py index 4b9e6c6..d5a5e62 100644 --- a/znnl/loss_functions/mean_power_error.py +++ b/znnl/loss_functions/mean_power_error.py @@ -54,4 +54,4 @@ def __name__(self): name : str The name of the class. """ - return f"mean_power_loss_{self.order}" + return "mean_power_loss_{self.order}" diff --git a/znnl/observables/__init__.py b/znnl/observables/__init__.py index 7ac9b0d..832d9c4 100644 --- a/znnl/observables/__init__.py +++ b/znnl/observables/__init__.py @@ -27,6 +27,4 @@ """ from znnl.observables.observable import Observable -__all__ = [ - Observable.__name__ -] +__all__ = [Observable.__name__] diff --git a/znnl/observables/observable.py b/znnl/observables/observable.py index 8dd5bd6..9b9a8c3 100644 --- a/znnl/observables/observable.py +++ b/znnl/observables/observable.py @@ -46,7 +46,7 @@ def __name__(self) -> str: The name of the observable. """ raise NotImplementedError("Implemented in child class.") - + @classmethod def __signature__(self, data_set: dict) -> tuple: """ @@ -63,7 +63,7 @@ def __signature__(self, data_set: dict) -> tuple: The signature of the observable. """ raise NotImplementedError("Implemented in child class.") - + @classmethod def __call__(self, data_set: dict) -> Union[str, np.ndarray, float]: """ @@ -78,6 +78,6 @@ def __call__(self, data_set: dict) -> Union[str, np.ndarray, float]: ------- value : Union[str, np.ndarray, float] The value of the observable. - + """ raise NotImplementedError("Implemented in child class.") diff --git a/znnl/training_strategies/partitioned_training.py b/znnl/training_strategies/partitioned_training.py index fdab899..93eeb43 100644 --- a/znnl/training_strategies/partitioned_training.py +++ b/znnl/training_strategies/partitioned_training.py @@ -88,8 +88,8 @@ def __init__( Random seed for the RNG. Uses a random int if not specified. recursive_mode : RecursiveMode Defining the recursive mode that can be used in training. - If the recursive mode is used, the training will be performed until a - condition is fulfilled. + If the recursive mode is used, the training will be performed + until a condition is fulfilled. The loss value at which point you consider the model trained. disable_loading_bar : bool Disable the output visualization of the loading bar. @@ -244,7 +244,7 @@ def train_model( Number of epochs to train over. Each epoch defines a training phase. train_ds_selection : list - (default = [slice(-1, None, None), slice(None, None, None)]) + default = [slice(-1, None, None), slice(None, None, None)] The train is selected by a np.array of indices or slices. Each slice or array defines a training phase. batch_size : list (default = [1, 1])