From f12e7af65ef14baec63c199af9a7e69a403b3c04 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:30:49 +0200 Subject: [PATCH] fix: compatibility audio do with new `scipy` (#2733) * compatibility audio do with new `scipy` * smaller array to fix torch.unique case --------- Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 3 +++ src/torchmetrics/__init__.py | 7 +++++++ src/torchmetrics/functional/nominal/__init__.py | 1 + src/torchmetrics/nominal/__init__.py | 1 + src/torchmetrics/utilities/imports.py | 1 + tests/unittests/classification/test_stat_scores.py | 4 ++-- 6 files changed, 15 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bac68f736f..0fc6c936492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721)) +- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733)) + + - Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722)) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index b1549dfaf8b..2fa370cb1c9 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -20,6 +20,13 @@ if not hasattr(PIL, "PILLOW_VERSION"): PIL.PILLOW_VERSION = PIL.__version__ +if package_available("scipy"): + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import ( # noqa: E402 CatMetric, diff --git a/src/torchmetrics/functional/nominal/__init__.py b/src/torchmetrics/functional/nominal/__init__.py index f29dd9302f0..772cb395895 100644 --- a/src/torchmetrics/functional/nominal/__init__.py +++ b/src/torchmetrics/functional/nominal/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa from torchmetrics.functional.nominal.pearson import ( diff --git a/src/torchmetrics/nominal/__init__.py b/src/torchmetrics/nominal/__init__.py index f23a7eb8c6b..e36da870308 100644 --- a/src/torchmetrics/nominal/__init__.py +++ b/src/torchmetrics/nominal/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from torchmetrics.nominal.cramers import CramersV from torchmetrics.nominal.fleiss_kappa import FleissKappa from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index b40a334558f..10affebf579 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -64,6 +64,7 @@ _MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic") _IPADIC_AVAILABLE = RequirementCache("ipadic") _SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece") +_SCIPI_AVAILABLE = RequirementCache("scipy") _SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 53fa78d0368..5ea4c206bc0 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -582,8 +582,8 @@ def test_support_for_int(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970.""" seed_all(42) metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0) - prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) - label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) + prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) + label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) score = metric(preds=prediction, target=label) assert score.shape == (1, 4, 5)