Skip to content

Commit

Permalink
raise a TypeError is an object of the wrong type is added to a contai…
Browse files Browse the repository at this point in the history
…ner.
  • Loading branch information
apdavison committed Jul 27, 2023
1 parent da4d0ec commit 74b5aa4
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 5 deletions.
18 changes: 16 additions & 2 deletions neo/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,22 @@ def _get_container(self, cls):
def add(self, *objects):
"""Add a new Neo object to the Container"""
for obj in objects:
container = self._get_container(obj.__class__)
container.append(obj)
if (
obj.__class__.__name__ in self._child_objects
or (
hasattr(obj, "proxy_for")
and obj.proxy_for.__name__ in self._child_objects
)
):
container = self._get_container(obj.__class__)
container.append(obj)
else:
raise TypeError(
f"Cannot add object of type {obj.__class__.__name__} "
f"to a {self.__class__.__name__}, can only add objects of the "
f"following types: {self._child_objects}"
)



def filter(self, targdict=None, data=True, container=False, recursive=True,
Expand Down
6 changes: 4 additions & 2 deletions neo/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
self.allowed_types = None
else:
self.allowed_types = tuple(allowed_types)
for type_ in self.allowed_types:
if type_.__name__ not in self._child_objects:
raise TypeError(f"Groups can not contain objects of type {type_.__name__}")

if objects:
self.add(*objects)
Expand Down Expand Up @@ -140,8 +143,7 @@ def add(self, *objects):
if self.allowed_types and not isinstance(obj, self.allowed_types):
raise TypeError("This Group can only contain {}, but not {}"
"".format(self.allowed_types, type(obj)))
container = self._get_container(obj.__class__)
container.append(obj)
super().add(*objects)

def walk(self):
"""
Expand Down
6 changes: 5 additions & 1 deletion neo/test/coretest/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from neo.core import SpikeTrain, AnalogSignal, Event
from neo.test.tools import (assert_neo_object_is_compliant,
assert_same_sub_schema)
from neo.test.generate_datasets import random_block, simple_block
from neo.test.generate_datasets import random_block, simple_block, random_signal


N_EXAMPLES = 5
Expand Down Expand Up @@ -493,6 +493,10 @@ def test_add(self):
new_blk.add(*blk.segments)
assert len(new_blk.segments) == n_segs_start + len(blk.segments)

def test_add_invalid_type_raises_Exception(self):
new_blk = Block()
self.assertRaises(TypeError, new_blk.add, random_signal())


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions neo/test/coretest/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from neo.core.segment import Segment
from neo.core.view import ChannelView
from neo.core.group import Group
from neo.core.block import Block


class TestGroup(unittest.TestCase):
Expand Down Expand Up @@ -91,3 +92,9 @@ def test_walk(self):
target.extend([children[1], children[2], *grandchildren[2]])
self.assertEqual(flattened,
target)

def test_add_invalid_type_raises_Exception(self):
group = Group()
self.assertRaises(TypeError, group.add, Block())

self.assertRaises(TypeError, Group, allowed_types=[Block])
4 changes: 4 additions & 0 deletions neo/test/coretest/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ def test_add(self):
seg.add(proxy_epoch)
assert len(seg.epochs) == 1

def test_add_invalid_type_raises_Exception(self):
seg = Segment()
self.assertRaises(TypeError, seg.add, Block())


if __name__ == "__main__":
unittest.main()

0 comments on commit 74b5aa4

Please sign in to comment.