Skip to content

Commit

Permalink
fix: compatibility audio do with new scipy (#2733)
Browse files Browse the repository at this point in the history
* compatibility audio do with new `scipy`
* smaller array to fix torch.unique case

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
Borda and SkafteNicki authored Sep 11, 2024
1 parent 80929b5 commit 96ceda0
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733))


- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722))


Expand Down
8 changes: 8 additions & 0 deletions src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@
_ONNXRUNTIME_AVAILABLE,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_SCIPI_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

if _SCIPI_AVAILABLE:
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"PermutationInvariantTraining",
"ScaleInvariantSignalDistortionRatio",
Expand Down
8 changes: 8 additions & 0 deletions src/torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@
_ONNXRUNTIME_AVAILABLE,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_SCIPI_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

if _SCIPI_AVAILABLE:
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"permutation_invariant_training",
"pit_permutate",
Expand Down
9 changes: 9 additions & 0 deletions src/torchmetrics/functional/nominal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa
from torchmetrics.functional.nominal.pearson import (
Expand All @@ -19,6 +20,14 @@
)
from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix
from torchmetrics.utilities.imports import _SCIPI_AVAILABLE

if _SCIPI_AVAILABLE:
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"cramers_v",
Expand Down
9 changes: 9 additions & 0 deletions src/torchmetrics/nominal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.nominal.cramers import CramersV
from torchmetrics.nominal.fleiss_kappa import FleissKappa
from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient
from torchmetrics.nominal.theils_u import TheilsU
from torchmetrics.nominal.tschuprows import TschuprowsT
from torchmetrics.utilities.imports import _SCIPI_AVAILABLE

if _SCIPI_AVAILABLE:
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"CramersV",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
_MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic")
_IPADIC_AVAILABLE = RequirementCache("ipadic")
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
_SCIPI_AVAILABLE = RequirementCache("scipy")
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@ def test_support_for_int():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970."""
seed_all(42)
metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
score = metric(preds=prediction, target=label)
assert score.shape == (1, 4, 5)

Expand Down

0 comments on commit 96ceda0

Please sign in to comment.