diff --git a/escnn/kernels/steerable_basis.py b/escnn/kernels/steerable_basis.py index 5091f6fc..fa5513ee 100644 --- a/escnn/kernels/steerable_basis.py +++ b/escnn/kernels/steerable_basis.py @@ -5,6 +5,7 @@ from escnn.group import Group from escnn.group import IrreducibleRepresentation from escnn.group import Representation +from escnn.utils import unique_ever_seen import torch @@ -238,9 +239,9 @@ def __init__(self, js = set() # loop over all input irreps - for i_irrep_id in set(in_repr.irreps): + for i_irrep_id in unique_ever_seen(in_repr.irreps): # loop over all output irreps - for o_irrep_id in set(out_repr.irreps): + for o_irrep_id in unique_ever_seen(out_repr.irreps): try: # retrieve the irrep intertwiner basis intertwiner_basis = irreps_basis._generator(basis, i_irrep_id, o_irrep_id, **kwargs) diff --git a/escnn/kernels/wignereckart_solver.py b/escnn/kernels/wignereckart_solver.py index 2f9e877e..d14a6358 100644 --- a/escnn/kernels/wignereckart_solver.py +++ b/escnn/kernels/wignereckart_solver.py @@ -4,6 +4,7 @@ from .steerable_filters_basis import SteerableFiltersBasis from escnn.group import * +from escnn.utils import unique_ever_seen import torch @@ -302,7 +303,7 @@ def __init__(self, _js_restriction = defaultdict(list) # for each harmonic j' to consider - for _j in set(_j for _j, _ in basis.js): + for _j in unique_ever_seen(_j for _j, _ in basis.js): if basis.multiplicity(_j) == 0: continue diff --git a/escnn/nn/modules/basismanager/basisexpansion_blocks.py b/escnn/nn/modules/basismanager/basisexpansion_blocks.py index 311e1c19..f4da1d7b 100644 --- a/escnn/nn/modules/basismanager/basisexpansion_blocks.py +++ b/escnn/nn/modules/basismanager/basisexpansion_blocks.py @@ -2,7 +2,7 @@ from escnn.kernels import KernelBasis, EmptyBasisException from escnn.group import Representation from escnn.nn.modules import utils -from escnn.nn.modules.utils import unique_ever_seen +from escnn.utils import unique_ever_seen from .basismanager import BasisManager from .basisexpansion_singleblock import block_basisexpansion diff --git a/escnn/nn/modules/basismanager/basissampler_blocks.py b/escnn/nn/modules/basismanager/basissampler_blocks.py index 6c617e7a..8f9458a3 100644 --- a/escnn/nn/modules/basismanager/basissampler_blocks.py +++ b/escnn/nn/modules/basismanager/basissampler_blocks.py @@ -1,13 +1,13 @@ from escnn.group import Representation from escnn.kernels import KernelBasis, EmptyBasisException +from escnn.utils import unique_ever_seen from escnn.nn.modules.basismanager import retrieve_indices from .basismanager import BasisManager from escnn.nn.modules.basismanager.basissampler_singleblock import block_basissampler -from escnn.nn.modules.utils import unique_ever_seen from typing import Callable, Tuple, Dict, List, Iterable, Union from collections import defaultdict diff --git a/escnn/nn/modules/batchnormalization/gnorm.py b/escnn/nn/modules/batchnormalization/gnorm.py index 1a250426..09e7acae 100644 --- a/escnn/nn/modules/batchnormalization/gnorm.py +++ b/escnn/nn/modules/batchnormalization/gnorm.py @@ -5,7 +5,7 @@ from escnn.gspaces import * from escnn.nn import FieldType from escnn.nn import GeometricTensor -from escnn.nn.modules.utils import unique_ever_seen +from escnn.utils import unique_ever_seen from ..equivariant_module import EquivariantModule diff --git a/escnn/nn/modules/batchnormalization/iid.py b/escnn/nn/modules/batchnormalization/iid.py index 5d1039b2..a2779310 100644 --- a/escnn/nn/modules/batchnormalization/iid.py +++ b/escnn/nn/modules/batchnormalization/iid.py @@ -5,7 +5,7 @@ from escnn.gspaces import * from escnn.nn import FieldType from escnn.nn import GeometricTensor -from escnn.nn.modules.utils import unique_ever_seen +from escnn.utils import unique_ever_seen from ..equivariant_module import EquivariantModule diff --git a/escnn/nn/modules/utils.py b/escnn/nn/modules/utils.py index e1cd6ac9..03d7203a 100644 --- a/escnn/nn/modules/utils.py +++ b/escnn/nn/modules/utils.py @@ -1,6 +1,6 @@ from escnn.nn import FieldType -from typing import List, Dict, Tuple, Iterable +from typing import List, Dict, Tuple from collections import defaultdict @@ -55,14 +55,3 @@ def indexes_from_labels(in_type: FieldType, labels: List[str]) -> Dict[str, Tupl return groups - -def unique_ever_seen(iterable: Iterable) -> Iterable: - already_seen = set() - - for item in iterable: - if item in already_seen: - continue - else: - already_seen.add(item) - yield item - diff --git a/escnn/utils.py b/escnn/utils.py new file mode 100644 index 00000000..8078ab13 --- /dev/null +++ b/escnn/utils.py @@ -0,0 +1,13 @@ +from typing import Iterable + +def unique_ever_seen(iterable: Iterable) -> Iterable: + already_seen = set() + + for item in iterable: + if item in already_seen: + continue + else: + already_seen.add(item) + yield item + + diff --git a/test/nn/test_basisexpansion.py b/test/nn/test_basisexpansion.py index 1117dc77..51c02341 100644 --- a/test/nn/test_basisexpansion.py +++ b/test/nn/test_basisexpansion.py @@ -135,8 +135,8 @@ def compare(self, basis: BlocksBasisExpansion): for i, attr1 in enumerate(basis.get_basis_info()): attr2 = basis.get_element_info(i) - self.assertEquals(attr1, attr2) - self.assertEquals(attr1['id'], i) + self.assertEqual(attr1, attr2) + self.assertEqual(attr1['id'], i) for _ in range(5): w = torch.randn(basis.dimension()) @@ -144,8 +144,8 @@ def compare(self, basis: BlocksBasisExpansion): f1 = basis(w) f2 = basis(w) assert torch.allclose(f1, f2) - self.assertEquals(f1.shape[1], basis._input_size) - self.assertEquals(f1.shape[0], basis._output_size) + self.assertEqual(f1.shape[1], basis._input_size) + self.assertEqual(f1.shape[0], basis._output_size) def test_checkpoint_meshgrid(self): diff --git a/test/nn/test_basissampler.py b/test/nn/test_basissampler.py index 967ad47a..eaef8a6c 100644 --- a/test/nn/test_basissampler.py +++ b/test/nn/test_basissampler.py @@ -126,8 +126,8 @@ def compare(self, basis: BlocksBasisSampler, d: int): for i, attr1 in enumerate(basis.get_basis_info()): attr2 = basis.get_element_info(i) - self.assertEquals(attr1, attr2) - self.assertEquals(attr1['id'], i) + self.assertEqual(attr1, attr2) + self.assertEqual(attr1['id'], i) for _ in range(5): P = 20 @@ -150,8 +150,8 @@ def compare(self, basis: BlocksBasisSampler, d: int): f1 = basis(w, edge_delta) f2 = basis(w, edge_delta) self.assertTrue(torch.allclose(f1, f2)) - self.assertEquals(f1.shape[2], basis._input_size) - self.assertEquals(f1.shape[1], basis._output_size) + self.assertEqual(f1.shape[2], basis._input_size) + self.assertEqual(f1.shape[1], basis._output_size) y1 = basis.compute_messages(w, x_j, edge_delta, conv_first=False) y2 = basis.compute_messages(w, x_j, edge_delta, conv_first=True)