From a5dc69191999b10e0166b25d07b9c864db0d53f9 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Sat, 22 Feb 2025 19:21:13 +0100 Subject: [PATCH] fix: use numbers.Real for checking type np.float32 are not float for isinstance, let's use a more generic checking. --- src/modopt/base/types.py | 9 +++++---- src/modopt/opt/algorithms/base.py | 3 ++- src/modopt/signal/positivity.py | 5 +++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/modopt/base/types.py b/src/modopt/base/types.py index 9e9a15b9..a6b7fd37 100644 --- a/src/modopt/base/types.py +++ b/src/modopt/base/types.py @@ -6,6 +6,7 @@ """ +import numbers import numpy as np from modopt.interface.errors import warn @@ -68,14 +69,14 @@ def check_float(input_obj): check_int : related function """ - if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)): + if not isinstance(input_obj, (int, numbers.Real, list, tuple, np.ndarray)): raise TypeError("Invalid input type.") if isinstance(input_obj, int): input_obj = float(input_obj) elif isinstance(input_obj, (list, tuple)): input_obj = np.array(input_obj, dtype=float) elif isinstance(input_obj, np.ndarray) and ( - not np.issubdtype(input_obj.dtype, np.floating) + not np.issubdtype(input_obj.dtype, numbers.Real) ): input_obj = input_obj.astype(float) @@ -117,9 +118,9 @@ def check_int(input_obj): check_float : related function """ - if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)): + if not isinstance(input_obj, (int, numbers.Real, list, tuple, np.ndarray)): raise TypeError("Invalid input type.") - if isinstance(input_obj, float): + if isinstance(input_obj, numbers.Real): input_obj = int(input_obj) elif isinstance(input_obj, (list, tuple)): input_obj = np.array(input_obj, dtype=int) diff --git a/src/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py index f7391063..3927ee4d 100644 --- a/src/modopt/opt/algorithms/base.py +++ b/src/modopt/opt/algorithms/base.py @@ -3,6 +3,7 @@ from inspect import getmro import numpy as np +import numbers from tqdm.auto import tqdm from modopt.base import backend @@ -192,7 +193,7 @@ def _check_param(self, param_val): For invalid input type """ - if not isinstance(param_val, float): + if not isinstance(param_val, numbers.Real): raise TypeError("Algorithm parameter must be a float value.") def _check_param_update(self, param_update): diff --git a/src/modopt/signal/positivity.py b/src/modopt/signal/positivity.py index 8d7aa46c..7bde5265 100644 --- a/src/modopt/signal/positivity.py +++ b/src/modopt/signal/positivity.py @@ -7,6 +7,7 @@ """ +import numbers import numpy as np @@ -93,13 +94,13 @@ def positive(input_data, ragged=False): [1, 2, 3]]) """ - if not isinstance(input_data, (int, float, list, tuple, np.ndarray)): + if not isinstance(input_data, (int, numbers.Real, list, tuple, np.ndarray)): raise TypeError( "Invalid data type, input must be `int`, `float`, `list`, " + "`tuple` or `np.ndarray`.", ) - if isinstance(input_data, (int, float)): + if isinstance(input_data, (int, numbers.Real)): return pos_thresh(input_data) if ragged: