From 79208d376e5408d574fd7778135547172d6d98d7 Mon Sep 17 00:00:00 2001 From: Owen Hilyard Date: Mon, 20 Jan 2025 21:44:03 -0500 Subject: [PATCH 1/2] [stdlib] Expand LinkedList API and add additional testing Merges my earlier work with the one which appeared in the stdlib, expanding the API and test coverage. Also makes a few of the test_utils counters `Writable` to help them conform to more APIs. Signed-off-by: Owen Hilyard --- stdlib/src/collections/linked_list.mojo | 416 +++++++++++++-- stdlib/test/collections/test_linked_list.mojo | 503 +++++++++++++++++- stdlib/test/test_utils/types.mojo | 13 +- 3 files changed, 894 insertions(+), 38 deletions(-) diff --git a/stdlib/src/collections/linked_list.mojo b/stdlib/src/collections/linked_list.mojo index cb2e362749..51c2b51988 100644 --- a/stdlib/src/collections/linked_list.mojo +++ b/stdlib/src/collections/linked_list.mojo @@ -17,25 +17,29 @@ from collections._index_normalization import normalize_index @value -struct Node[ElementType: WritableCollectionElement]: +struct Node[ + ElementType: CollectionElement, +]: """A node in a linked list data structure. Parameters: ElementType: The type of element stored in the node. """ + alias NodePointer = UnsafePointer[Self] + var value: ElementType """The value stored in this node.""" - var prev: UnsafePointer[Node[ElementType]] + var prev: Self.NodePointer """The previous node in the list.""" - var next: UnsafePointer[Node[ElementType]] + var next: Self.NodePointer """The next node in the list.""" fn __init__( out self, owned value: ElementType, - prev: Optional[UnsafePointer[Node[ElementType]]], - next: Optional[UnsafePointer[Node[ElementType]]], + prev: Optional[Self.NodePointer], + next: Optional[Self.NodePointer], ): """Initialize a new Node with the given value and optional prev/next pointers. @@ -49,19 +53,29 @@ struct Node[ElementType: WritableCollectionElement]: self.prev = prev.value() if prev else __type_of(self.prev)() self.next = next.value() if next else __type_of(self.next)() - fn __str__(self) -> String: + fn __str__[ + ElementType: WritableCollectionElement + ](self: Node[ElementType]) -> String: """Convert this node's value to a string representation. + Parameters: + ElementType: Used to conditionally enable this function if + `ElementType` is `Writable`. + Returns: String representation of the node's value. """ - return String.write(self) + return String.write(self.value) @no_inline - fn write_to[W: Writer](self, mut writer: W): + fn write_to[ + ElementType: WritableCollectionElement, W: Writer + ](self: Node[ElementType], mut writer: W): """Write this node's value to the given writer. Parameters: + ElementType: Used to conditionally enable this function if + `ElementType` is `Writable`. W: The type of writer to write the value to. Args: @@ -70,7 +84,9 @@ struct Node[ElementType: WritableCollectionElement]: writer.write(self.value) -struct LinkedList[ElementType: WritableCollectionElement]: +struct LinkedList[ + ElementType: CollectionElement, +]: """A doubly-linked list implementation. A doubly-linked list is a data structure where each element points to both @@ -79,12 +95,14 @@ struct LinkedList[ElementType: WritableCollectionElement]: Parameters: ElementType: The type of elements stored in the list. Must implement - WritableCollectionElement. + CollectionElement. """ - var _head: UnsafePointer[Node[ElementType]] + alias NodePointer = UnsafePointer[Node[ElementType]] + + var _head: Self.NodePointer """The first node in the list.""" - var _tail: UnsafePointer[Node[ElementType]] + var _tail: Self.NodePointer """The last node in the list.""" var _size: Int """The number of elements in the list.""" @@ -104,30 +122,44 @@ struct LinkedList[ElementType: WritableCollectionElement]: self = Self(elements=elements^) fn __init__(out self, *, owned elements: VariadicListMem[ElementType, _]): - """Initialize a linked list with the given elements. + """ + Construct a list from a `VariadicListMem`. Args: - elements: Variable number of elements to initialize the list with. + elements: The elements to add to the list. """ self = Self() - for elem in elements: - self.append(elem[]) + var length = len(elements) + + for i in range(length): + var src = UnsafePointer.address_of(elements[i]) + var node = Self.NodePointer.alloc(1) + var dst = UnsafePointer.address_of(node[].value) + src.move_pointee_into(dst) + node[].next = Self.NodePointer() + node[].prev = self._tail + if not self._tail: + self._head = node + self._tail = node + else: + self._tail[].next = node + self._tail = node # Do not destroy the elements when their backing storage goes away. __mlir_op.`lit.ownership.mark_destroyed`( __get_mvalue_as_litref(elements) ) + self._size = length + fn __copyinit__(mut self, read other: Self): """Initialize this list as a copy of another list. Args: other: The list to copy from. """ - self._head = other._head - self._tail = other._tail - self._size = other._size + self = other.copy() fn __moveinit__(mut self, owned other: Self): """Initialize this list by moving elements from another list. @@ -157,10 +189,12 @@ struct LinkedList[ElementType: WritableCollectionElement]: Args: value: The value to append. """ - var node = Node(value^, self._tail, None) - var addr = UnsafePointer[__type_of(node)].alloc(1) - addr.init_pointee_move(node) - if self: + var addr = Self.NodePointer.alloc(1) + var value_ptr = UnsafePointer.address_of(addr[].value) + value_ptr.init_pointee_move(value^) + addr[].prev = self._tail + addr[].next = Self.NodePointer() + if self._tail: self._tail[].next = addr else: self._head = addr @@ -195,20 +229,138 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._tail = self._head self._head = prev - fn pop(mut self) -> ElementType: - """Remove and return the first element of the list. + fn pop(mut self) raises -> ElementType: + """Remove and return the last element of the list. + + Returns: + The last element in the list. + """ + var elem = self._tail + if not elem: + raise "Pop on empty list." + + var value = elem[].value + self._tail = elem[].prev + self._size -= 1 + if self._size == 0: + self._head = __type_of(self._head)() + else: + self._tail[].next = Self.NodePointer() + elem.free() + return value^ + + fn pop[I: Indexer](mut self, owned i: I) raises -> ElementType: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + Ownership of the indicated element. + """ + var current = self._get_node_ptr(Int(i)) + + if not current: + raise "Invalid index for pop" + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.value^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._size -= 1 + return data^ + + fn pop_if_present(mut self) -> Optional[ElementType]: + """Removes the head of the list and returns it, if it exists. Returns: - The first element in the list. + The head of the list, if it was present. """ var elem = self._tail + if not elem: + return Optional[ElementType]() var value = elem[].value self._tail = elem[].prev self._size -= 1 if self._size == 0: self._head = __type_of(self._head)() + else: + self._tail[].next = Self.NodePointer() + elem.free() return value^ + fn pop_if_present[ + I: Indexer + ](mut self, owned i: I) -> Optional[ElementType]: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + The element, if it was found. + """ + var current = self._get_node_ptr(Int(i)) + + if not current: + return Optional[ElementType]() + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.value^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._size -= 1 + return Optional[ElementType](data^) + + fn clear(mut self): + """Removes all elements from the list.""" + var current = self._head + while current: + var old = current + current = current[].next + old.destroy_pointee() + old.free() + + self._head = Self.NodePointer() + self._tail = Self.NodePointer() + self._size = 0 + fn copy(self) -> Self: """Create a deep copy of the list. @@ -222,6 +374,180 @@ struct LinkedList[ElementType: WritableCollectionElement]: curr = curr[].next return new^ + fn insert(mut self, owned idx: Int, owned elem: ElementType) raises: + """ + Insert an element `elem` into the list at index `idx`. + + Args: + idx: The index to insert `elem` at. + elem: The item to insert into the list. + """ + var i = max(0, index(idx) if idx >= 0 else index(idx) + len(self)) + + if i == 0: + var node = Self.NodePointer.alloc(1) + if not node: + raise "OOM" + node.init_pointee_move( + Node[ElementType](elem^, Self.NodePointer(), Self.NodePointer()) + ) + + if self._head: + node[].next = self._head + self._head[].prev = node + + self._head = node + + if not self._tail: + self._tail = node + + self._size += 1 + return + + i -= 1 + + var current = self._get_node_ptr(i) + if current: + var next = current[].next + var node = Self.NodePointer.alloc(1) + if not node: + raise "OOM" + var data = UnsafePointer.address_of(node[].value) + data[] = elem^ + node[].next = next + node[].prev = current + if next: + next[].prev = node + current[].next = node + if node[].next == Self.NodePointer(): + self._tail = node + if node[].prev == Self.NodePointer(): + self._head = node + self._size += 1 + else: + raise "index out of bounds" + + fn extend(mut self, owned other: Self): + """ + Extends the list with another. + O(1) time complexity. + + Args: + other: The list to append to this one. + """ + if self._tail: + self._tail[].next = other._head + if other._head: + other._head[].prev = self._tail + if other._tail: + self._tail = other._tail + + self._size += other._size + else: + self._head = other._head + self._tail = other._tail + self._size = other._size + + other._head = Self.NodePointer() + other._tail = Self.NodePointer() + + fn count[ + ElementType: EqualityComparableCollectionElement + ](self: LinkedList[ElementType], read elem: ElementType) -> UInt: + """ + Count the occurrences of `elem` in the list. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + elem: The element to search for. + + Returns: + The number of occurrences of `elem` in the list. + """ + var current = self._head + var count = 0 + while current: + if current[].value == elem: + count += 1 + + current = current[].next + + return count + + fn __contains__[ + ElementType: EqualityComparableCollectionElement, // + ](self: LinkedList[ElementType], value: ElementType) -> Bool: + """ + Checks if the list contains `value`. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + value: The value to search for in the list. + + Returns: + Whether the list contains `value`. + """ + var current = self._head + while current: + if current[].value == value: + return True + current = current[].next + + return False + + fn __eq__[ + ElementType: EqualityComparableCollectionElement, // + ]( + read self: LinkedList[ElementType], read other: LinkedList[ElementType] + ) -> Bool: + """ + Checks if the two lists are equal. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are equal. + """ + if self._size != other._size: + return False + + var self_cursor = self._head + var other_cursor = other._head + + while self_cursor: + if self_cursor[].value != other_cursor[].value: + return False + + self_cursor = self_cursor[].next + other_cursor = other_cursor[].next + + return True + + fn __ne__[ + ElementType: EqualityComparableCollectionElement, // + ](self: LinkedList[ElementType], other: LinkedList[ElementType]) -> Bool: + """ + Checks if the two lists are not equal. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are not equal. + """ + return not (self == other) + fn _get_node_ptr(ref self, index: Int) -> UnsafePointer[Node[ElementType]]: """Get a pointer to the node at the specified index. @@ -287,17 +613,31 @@ struct LinkedList[ElementType: WritableCollectionElement]: """ return len(self) != 0 - fn __str__(self) -> String: + fn __str__[ + ElementType: WritableCollectionElement + ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Parameters: + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. + Returns: String representation of the list. """ - return String.write(self) + var writer = String() + self._write(writer) + return writer - fn __repr__(self) -> String: + fn __repr__[ + ElementType: WritableCollectionElement + ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Parameters: + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. + Returns: String representation of the list. """ @@ -305,11 +645,15 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._write(writer, prefix="LinkedList(", suffix=")") return writer - fn write_to[W: Writer](self, mut writer: W): + fn write_to[ + W: Writer, ElementType: WritableCollectionElement + ](self: LinkedList[ElementType], mut writer: W): """Write the list to the given writer. Parameters: W: The type of writer to write the list to. + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. Args: writer: The writer to write the list to. @@ -318,8 +662,14 @@ struct LinkedList[ElementType: WritableCollectionElement]: @no_inline fn _write[ - W: Writer - ](self, mut writer: W, *, prefix: String = "[", suffix: String = "]"): + W: Writer, ElementType: WritableCollectionElement + ]( + self: LinkedList[ElementType], + mut writer: W, + *, + prefix: String = "[", + suffix: String = "]", + ): if not self: return writer.write(prefix, suffix) @@ -328,6 +678,6 @@ struct LinkedList[ElementType: WritableCollectionElement]: for i in range(len(self)): if i: writer.write(", ") - writer.write(curr[]) + writer.write(curr[].value) curr = curr[].next writer.write(suffix) diff --git a/stdlib/test/collections/test_linked_list.mojo b/stdlib/test/collections/test_linked_list.mojo index a5f3467736..6af2d0c4e4 100644 --- a/stdlib/test/collections/test_linked_list.mojo +++ b/stdlib/test/collections/test_linked_list.mojo @@ -12,8 +12,9 @@ # ===----------------------------------------------------------------------=== # # RUN: %mojo %s -from collections import LinkedList -from testing import assert_equal +from collections import LinkedList, Optional +from testing import assert_equal, assert_raises, assert_true, assert_false +from test_utils import CopyCounter, MoveCounter def test_construction(): @@ -101,12 +102,487 @@ def test_setitem(): def test_str(): var l1 = LinkedList[Int](1, 2, 3) - assert_equal(String(l1), "[1, 2, 3]") + assert_equal(l1.__str__(), "[1, 2, 3]") def test_repr(): var l1 = LinkedList[Int](1, 2, 3) - assert_equal(repr(l1), "LinkedList(1, 2, 3)") + assert_equal(l1.__repr__(), "LinkedList(1, 2, 3)") + + +def test_pop_on_empty_list(): + with assert_raises(): + var ll = LinkedList[Int]() + _ = ll.pop() + + +def test_optional_pop_on_empty_linked_list(): + var ll = LinkedList[Int]() + var result = ll.pop_if_present() + assert_false(Bool(result)) + + +def test_list(): + var list = LinkedList[Int]() + + for i in range(5): + list.append(i) + + assert_equal(5, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + assert_equal(3, list[3]) + assert_equal(4, list[4]) + + assert_equal(0, list[-5]) + assert_equal(3, list[-2]) + assert_equal(4, list[-1]) + + list[2] = -2 + assert_equal(-2, list[2]) + + list[-5] = 5 + assert_equal(5, list[-5]) + list[-2] = 3 + assert_equal(3, list[-2]) + list[-1] = 7 + assert_equal(7, list[-1]) + + +def test_list_clear(): + var list = LinkedList[Int](1, 2, 3) + assert_equal(len(list), 3) + list.clear() + + assert_equal(len(list), 0) + + +def test_list_to_bool_conversion(): + assert_false(LinkedList[String]()) + assert_true(LinkedList[String]("a")) + assert_true(LinkedList[String]("", "a")) + assert_true(LinkedList[String]("")) + + +def test_list_pop(): + var list = LinkedList[Int]() + # Test pop with index + for i in range(6): + list.append(i) + + assert_equal(6, len(list)) + + # try popping from index 3 for 3 times + for i in range(3, 6): + assert_equal[Int](i, list.pop(3)) + + # list should have 3 elements now + assert_equal(3, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + + # Test pop with negative index + for i in range(0, 2): + var popped: Int = list.pop(-len(list)) + assert_equal(i, popped) + + # test default index as well + assert_equal(2, list.pop()) + list.append(2) + assert_equal(2, list.pop()) + + # list should be empty now + assert_equal(0, len(list)) + + +def test_list_variadic_constructor(): + var l = LinkedList[Int](2, 4, 6) + assert_equal(3, len(l)) + assert_equal(2, l[0]) + assert_equal(4, l[1]) + assert_equal(6, l[2]) + + l.append(8) + assert_equal(4, len(l)) + assert_equal(8, l[3]) + + # + # Test variadic construct copying behavior + # + + var l2 = LinkedList[CopyCounter]( + CopyCounter(), CopyCounter(), CopyCounter() + ) + + assert_equal(len(l2), 3) + assert_equal(l2[0].copy_count, 0) + assert_equal(l2[1].copy_count, 0) + assert_equal(l2[2].copy_count, 0) + + +def test_list_reverse(): + # + # Test reversing the list [] + # + + var vec = LinkedList[Int]() + + assert_equal(len(vec), 0) + + vec.reverse() + + assert_equal(len(vec), 0) + + # + # Test reversing the list [123] + # + + vec = LinkedList[Int]() + + vec.append(123) + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + vec.reverse() + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + # + # Test reversing the list ["one", "two", "three"] + # + + var vec2 = LinkedList[String]("one", "two", "three") + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "one") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "three") + + vec2.reverse() + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "three") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "one") + + # + # Test reversing the list [5, 10] + # + + vec = LinkedList[Int]() + vec.append(5) + vec.append(10) + + assert_equal(len(vec), 2) + assert_equal(vec[0], 5) + assert_equal(vec[1], 10) + + vec.reverse() + + assert_equal(len(vec), 2) + assert_equal(vec[0], 10) + assert_equal(vec[1], 5) + + +def test_list_insert(): + # + # Test the list [1, 2, 3] created with insert + # + + var v1 = LinkedList[Int]() + v1.insert(len(v1), 1) + v1.insert(len(v1), 3) + v1.insert(1, 2) + + assert_equal(len(v1), 3) + assert_equal(v1[0], 1) + assert_equal(v1[1], 2) + assert_equal(v1[2], 3) + + print(v1.__str__()) + + # + # Test the list [1, 2, 3, 4, 5] created with negative and positive index + # + + var v2 = LinkedList[Int]() + v2.insert(-1729, 2) + v2.insert(len(v2), 3) + v2.insert(len(v2), 5) + v2.insert(-1, 4) + v2.insert(-len(v2), 1) + print(v2.__str__()) + + assert_equal(len(v2), 5) + assert_equal(v2[0], 1) + assert_equal(v2[1], 2) + assert_equal(v2[2], 3) + assert_equal(v2[3], 4) + assert_equal(v2[4], 5) + + # + # Test the list [1, 2, 3, 4] created with negative index + # + + var v3 = LinkedList[Int]() + v3.insert(-11, 4) + v3.insert(-13, 3) + v3.insert(-17, 2) + v3.insert(-19, 1) + + assert_equal(len(v3), 4) + assert_equal(v3[0], 1) + assert_equal(v3[1], 2) + assert_equal(v3[2], 3) + assert_equal(v3[3], 4) + + # + # Test the list [1, 2, 3, 4, 5, 6, 7, 8] created with insert + # + + var v4 = LinkedList[Int]() + for i in range(4): + v4.insert(0, 4 - i) + v4.insert(len(v4), 4 + i + 1) + + for i in range(len(v4)): + assert_equal(v4[i], i + 1) + + +def test_list_extend_non_trivial(): + # Tests three things: + # - extend() for non-plain-old-data types + # - extend() with mixed-length self and other lists + # - extend() using optimal number of __moveinit__() calls + + # Preallocate with enough capacity to avoid reallocation making the + # move count checks below flaky. + var v1 = LinkedList[MoveCounter[String]]() + v1.append(MoveCounter[String]("Hello")) + v1.append(MoveCounter[String]("World")) + + var v2 = LinkedList[MoveCounter[String]]() + v2.append(MoveCounter[String]("Foo")) + v2.append(MoveCounter[String]("Bar")) + v2.append(MoveCounter[String]("Baz")) + + v1.extend(v2^) + + assert_equal(len(v1), 5) + assert_equal(v1[0].value, "Hello") + assert_equal(v1[1].value, "World") + assert_equal(v1[2].value, "Foo") + assert_equal(v1[3].value, "Bar") + assert_equal(v1[4].value, "Baz") + + assert_equal(v1[0].move_count, 1) + assert_equal(v1[1].move_count, 1) + assert_equal(v1[2].move_count, 1) + assert_equal(v1[3].move_count, 1) + assert_equal(v1[4].move_count, 1) + + +def test_2d_dynamic_list(): + var list = LinkedList[LinkedList[Int]]() + + for i in range(2): + var v = LinkedList[Int]() + for j in range(3): + v.append(i + j) + list.append(v) + + assert_equal(0, list[0][0]) + assert_equal(1, list[0][1]) + assert_equal(2, list[0][2]) + assert_equal(1, list[1][0]) + assert_equal(2, list[1][1]) + assert_equal(3, list[1][2]) + + assert_equal(2, len(list)) + + assert_equal(3, len(list[0])) + + list[0].clear() + assert_equal(0, len(list[0])) + + list.clear() + assert_equal(0, len(list)) + + +def test_list_explicit_copy(): + var list = LinkedList[CopyCounter]() + list.append(CopyCounter()) + var list_copy = list.copy() + assert_equal(0, list[0].copy_count) + assert_equal(1, list_copy[0].copy_count) + + var l2 = LinkedList[Int]() + for i in range(10): + l2.append(i) + + var l2_copy = l2.copy() + assert_equal(len(l2), len(l2_copy)) + for i in range(len(l2)): + assert_equal(l2[i], l2_copy[i]) + + +@value +struct CopyCountedStruct(CollectionElement): + var counter: CopyCounter + var value: String + + fn __init__(out self, *, other: Self): + self.counter = other.counter.copy() + self.value = other.value.copy() + + @implicit + fn __init__(out self, value: String): + self.counter = CopyCounter() + self.value = value + + +def test_no_extra_copies_with_sugared_set_by_field(): + var list = LinkedList[LinkedList[CopyCountedStruct]]() + var child_list = LinkedList[CopyCountedStruct]() + child_list.append(CopyCountedStruct("Hello")) + child_list.append(CopyCountedStruct("World")) + + # No copies here. Constructing with LinkedList[CopyCountedStruct](CopyCountedStruct("Hello")) is a copy. + assert_equal(0, child_list[0].counter.copy_count) + assert_equal(0, child_list[1].counter.copy_count) + + list.append(child_list^) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + # list[0][1] makes a copy for reasons I cannot determine + list.__getitem__(0).__getitem__(1).value = "Mojo" + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + assert_equal("Mojo", list[0][1].value) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + +def test_list_boolable(): + assert_true(LinkedList[Int](1)) + assert_false(LinkedList[Int]()) + + +def test_list_count(): + var list = LinkedList[Int](1, 2, 3, 2, 5, 6, 7, 8, 9, 10) + assert_equal(1, list.count(1)) + assert_equal(2, list.count(2)) + assert_equal(0, list.count(4)) + + var list2 = LinkedList[Int]() + assert_equal(0, list2.count(1)) + + +def test_list_contains(): + var x = LinkedList[Int](1, 2, 3) + assert_false(0 in x) + assert_true(1 in x) + assert_false(4 in x) + + # TODO: implement LinkedList.__eq__ for Self[ComparableCollectionElement] + # var y = LinkedList[LinkedList[Int]]() + # y.append(LinkedList(1,2)) + # assert_equal(LinkedList(1,2) in y,True) + # assert_equal(LinkedList(0,1) in y,False) + + +def test_list_eq_ne(): + var l1 = LinkedList[Int](1, 2, 3) + var l2 = LinkedList[Int](1, 2, 3) + assert_true(l1 == l2) + assert_false(l1 != l2) + + var l3 = LinkedList[Int](1, 2, 3, 4) + assert_false(l1 == l3) + assert_true(l1 != l3) + + var l4 = LinkedList[Int]() + var l5 = LinkedList[Int]() + assert_true(l4 == l5) + assert_true(l1 != l4) + + var l6 = LinkedList[String]("a", "b", "c") + var l7 = LinkedList[String]("a", "b", "c") + var l8 = LinkedList[String]("a", "b") + assert_true(l6 == l7) + assert_false(l6 != l7) + assert_false(l6 == l8) + + +def test_indexing(): + var l = LinkedList[Int](1, 2, 3) + assert_equal(l[Int(1)], 2) + assert_equal(l[False], 1) + assert_equal(l[True], 2) + assert_equal(l[2], 3) + + +# ===-------------------------------------------------------------------===# +# LinkedList dtor tests +# ===-------------------------------------------------------------------===# +var g_dtor_count: Int = 0 + + +struct DtorCounter(CollectionElement, Writable): + # NOTE: payload is required because LinkedList does not support zero sized structs. + var payload: Int + + fn __init__(out self): + self.payload = 0 + + fn __init__(out self, *, other: Self): + self.payload = other.payload + + fn __copyinit__(out self, existing: Self, /): + self.payload = existing.payload + + fn __moveinit__(out self, owned existing: Self, /): + self.payload = existing.payload + existing.payload = 0 + + fn __del__(owned self): + g_dtor_count += 1 + + fn write_to[W: Writer](self, mut writer: W): + writer.write("DtorCounter(") + writer.write(String(g_dtor_count)) + writer.write(")") + + +def inner_test_list_dtor(): + # explicitly reset global counter + g_dtor_count = 0 + + var l = LinkedList[DtorCounter]() + assert_equal(g_dtor_count, 0) + + l.append(DtorCounter()) + assert_equal(g_dtor_count, 0) + + l^.__del__() + assert_equal(g_dtor_count, 1) + + +def test_list_dtor(): + # call another function to force the destruction of the list + inner_test_list_dtor() + + # verify we still only ran the destructor once + assert_equal(g_dtor_count, 1) def main(): @@ -120,3 +596,22 @@ def main(): test_setitem() test_str() test_repr() + test_pop_on_empty_list() + test_optional_pop_on_empty_linked_list() + test_list() + test_list_clear() + test_list_to_bool_conversion() + test_list_pop() + test_list_variadic_constructor() + test_list_reverse() + test_list_extend_non_trivial() + test_list_explicit_copy() + test_no_extra_copies_with_sugared_set_by_field() + test_2d_dynamic_list() + test_list_boolable() + test_list_count() + test_list_contains() + test_indexing() + test_list_dtor() + test_list_insert() + test_list_eq_ne() diff --git a/stdlib/test/test_utils/types.mojo b/stdlib/test/test_utils/types.mojo index 01870acebc..0610304a36 100644 --- a/stdlib/test/test_utils/types.mojo +++ b/stdlib/test/test_utils/types.mojo @@ -88,7 +88,7 @@ struct ImplicitCopyOnly(Copyable): # ===----------------------------------------------------------------------=== # -struct CopyCounter(CollectionElement, ExplicitlyCopyable): +struct CopyCounter(CollectionElement, ExplicitlyCopyable, Writable): """Counts the number of copies performed on a value.""" var copy_count: Int @@ -108,6 +108,11 @@ struct CopyCounter(CollectionElement, ExplicitlyCopyable): fn copy(self) -> Self: return self + fn write_to[W: Writer](self, mut writer: W): + writer.write("CopyCounter(") + writer.write(String(self.copy_count)) + writer.write(")") + # ===----------------------------------------------------------------------=== # # MoveCounter @@ -117,6 +122,7 @@ struct CopyCounter(CollectionElement, ExplicitlyCopyable): struct MoveCounter[T: CollectionElementNew]( CollectionElement, CollectionElementNew, + Writable, ): """Counts the number of moves performed on a value.""" @@ -155,6 +161,11 @@ struct MoveCounter[T: CollectionElementNew]( fn copy(self) -> Self: return self + fn write_to[W: Writer](self, mut writer: W): + writer.write("MoveCounter(") + writer.write(String(self.move_count)) + writer.write(")") + # ===----------------------------------------------------------------------=== # # ValueDestructorRecorder From 772dbbe56758600a781f71ac15658d196d245faf Mon Sep 17 00:00:00 2001 From: Owen Hilyard Date: Wed, 22 Jan 2025 14:31:36 -0500 Subject: [PATCH 2/2] [stdlib] Extended Linked List cleanup and fixups Minor tweaks, mostly consisting of removing unnecessary things like some leftover debug prints and explicit type parameters. Also adds time complexity information to all public methods of `LinkedList`. Signed-off-by: Owen Hilyard --- stdlib/src/builtin/value.mojo | 38 ++++ stdlib/src/collections/linked_list.mojo | 185 +++++++++++++----- stdlib/test/collections/test_linked_list.mojo | 58 +----- stdlib/test/collections/test_list.mojo | 48 +---- stdlib/test/test_utils/__init__.mojo | 3 + stdlib/test/test_utils/types.mojo | 65 +++++- 6 files changed, 246 insertions(+), 151 deletions(-) diff --git a/stdlib/src/builtin/value.mojo b/stdlib/src/builtin/value.mojo index c3868f1a5a..4b991dc06b 100644 --- a/stdlib/src/builtin/value.mojo +++ b/stdlib/src/builtin/value.mojo @@ -291,3 +291,41 @@ trait BoolableKeyElement(Boolable, KeyElement): """ pass + + +trait EqualityComparableWritableCollectionElement( + WritableCollectionElement, EqualityComparable +): + """A trait that combines the CollectionElement, Writable and + EqualityComparable traits. + + This trait requires types to implement CollectionElement, Writable and + EqualityComparable interfaces, allowing them to be used in collections, + compared, and written to output. + """ + + pass + + +trait WritableCollectionElementNew(CollectionElementNew, Writable): + """A trait that combines the CollectionElement and Writable traits. + + This trait requires types to implement both CollectionElement and Writable + interfaces, allowing them to be used in collections and written to output. + """ + + pass + + +trait EqualityComparableWritableCollectionElementNew( + WritableCollectionElementNew, EqualityComparable +): + """A trait that combines the CollectionElement, Writable and + EqualityComparable traits. + + This trait requires types to implement CollectionElement, Writable and + EqualityComparable interfaces, allowing them to be used in collections, + compared, and written to output. + """ + + pass diff --git a/stdlib/src/collections/linked_list.mojo b/stdlib/src/collections/linked_list.mojo index 51c2b51988..51780b4153 100644 --- a/stdlib/src/collections/linked_list.mojo +++ b/stdlib/src/collections/linked_list.mojo @@ -14,6 +14,7 @@ from memory import UnsafePointer from collections import Optional from collections._index_normalization import normalize_index +from os import abort @value @@ -26,20 +27,20 @@ struct Node[ ElementType: The type of element stored in the node. """ - alias NodePointer = UnsafePointer[Self] + alias _NodePointer = UnsafePointer[Self] var value: ElementType """The value stored in this node.""" - var prev: Self.NodePointer + var prev: Self._NodePointer """The previous node in the list.""" - var next: Self.NodePointer + var next: Self._NodePointer """The next node in the list.""" fn __init__( out self, owned value: ElementType, - prev: Optional[Self.NodePointer], - next: Optional[Self.NodePointer], + prev: Optional[Self._NodePointer], + next: Optional[Self._NodePointer], ): """Initialize a new Node with the given value and optional prev/next pointers. @@ -50,8 +51,8 @@ struct Node[ next: Optional pointer to the next node. """ self.value = value^ - self.prev = prev.value() if prev else __type_of(self.prev)() - self.next = next.value() if next else __type_of(self.next)() + self.prev = prev.value() if prev else Self._NodePointer() + self.next = next.value() if next else Self._NodePointer() fn __str__[ ElementType: WritableCollectionElement @@ -98,24 +99,30 @@ struct LinkedList[ CollectionElement. """ - alias NodePointer = UnsafePointer[Node[ElementType]] + alias _NodePointer = UnsafePointer[Node[ElementType]] - var _head: Self.NodePointer + var _head: Self._NodePointer """The first node in the list.""" - var _tail: Self.NodePointer + var _tail: Self._NodePointer """The last node in the list.""" var _size: Int """The number of elements in the list.""" fn __init__(out self): - """Initialize an empty linked list.""" - self._head = __type_of(self._head)() - self._tail = __type_of(self._tail)() + """ + Initialize an empty linked list. + + Time Complexity: O(1) + """ + self._head = Self._NodePointer() + self._tail = Self._NodePointer() self._size = 0 fn __init__(mut self, owned *elements: ElementType): """Initialize a linked list with the given elements. + Time Complexity: O(n) in len(elements) + Args: elements: Variable number of elements to initialize the list with. """ @@ -125,6 +132,8 @@ struct LinkedList[ """ Construct a list from a `VariadicListMem`. + Time Complexity: O(n) in len(elements) + Args: elements: The elements to add to the list. """ @@ -134,16 +143,18 @@ struct LinkedList[ for i in range(length): var src = UnsafePointer.address_of(elements[i]) - var node = Self.NodePointer.alloc(1) + var node = Self._NodePointer.alloc(1) + if not node: + abort("Out of memory") var dst = UnsafePointer.address_of(node[].value) src.move_pointee_into(dst) - node[].next = Self.NodePointer() + node[].next = Self._NodePointer() node[].prev = self._tail - if not self._tail: - self._head = node + if self._tail: + self._tail[].next = node self._tail = node else: - self._tail[].next = node + self._head = node self._tail = node # Do not destroy the elements when their backing storage goes away. @@ -156,6 +167,8 @@ struct LinkedList[ fn __copyinit__(mut self, read other: Self): """Initialize this list as a copy of another list. + Time Complexity: O(n) in len(elements) + Args: other: The list to copy from. """ @@ -164,18 +177,24 @@ struct LinkedList[ fn __moveinit__(mut self, owned other: Self): """Initialize this list by moving elements from another list. + Time Complexity: O(1) + Args: other: The list to move elements from. """ self._head = other._head self._tail = other._tail self._size = other._size - other._head = __type_of(other._head)() - other._tail = __type_of(other._tail)() + other._head = Self._NodePointer() + other._tail = Self._NodePointer() other._size = 0 fn __del__(owned self): - """Clean up the list by freeing all nodes.""" + """ + Clean up the list by freeing all nodes. + + Time Complexity: O(n) in len(self) + """ var curr = self._head while curr: var next = curr[].next @@ -184,16 +203,21 @@ struct LinkedList[ curr = next fn append(mut self, owned value: ElementType): - """Add an element to the end of the list. + """ + Add an element to the end of the list. + + Time Complexity: O(1) Args: value: The value to append. """ - var addr = Self.NodePointer.alloc(1) + var addr = Self._NodePointer.alloc(1) + if not addr: + abort("Out of memory") var value_ptr = UnsafePointer.address_of(addr[].value) value_ptr.init_pointee_move(value^) addr[].prev = self._tail - addr[].next = Self.NodePointer() + addr[].next = Self._NodePointer() if self._tail: self._tail[].next = addr else: @@ -202,13 +226,18 @@ struct LinkedList[ self._size += 1 fn prepend(mut self, owned value: ElementType): - """Add an element to the beginning of the list. + """ + Add an element to the beginning of the list. + + Time Complexity: O(1) Args: value: The value to prepend. """ var node = Node(value^, None, self._head) - var addr = UnsafePointer[__type_of(node)].alloc(1) + var addr = Self._NodePointer.alloc(1) + if not addr: + abort("Out of memory") addr.init_pointee_move(node) if self: self._head[].prev = addr @@ -218,8 +247,12 @@ struct LinkedList[ self._size += 1 fn reverse(mut self): - """Reverse the order of elements in the list.""" - var prev = __type_of(self._head)() + """ + Reverse the order of elements in the list. + + Time Complexity: O(n) in len(self) + """ + var prev = Self._NodePointer() var curr = self._head while curr: var next = curr[].next @@ -232,6 +265,8 @@ struct LinkedList[ fn pop(mut self) raises -> ElementType: """Remove and return the last element of the list. + Time Complexity: O(1) + Returns: The last element in the list. """ @@ -243,9 +278,9 @@ struct LinkedList[ self._tail = elem[].prev self._size -= 1 if self._size == 0: - self._head = __type_of(self._head)() + self._head = Self._NodePointer() else: - self._tail[].next = Self.NodePointer() + self._tail[].next = Self._NodePointer() elem.free() return value^ @@ -254,6 +289,8 @@ struct LinkedList[ Remove the ith element of the list, counting from the tail if given a negative index. + Time Complexity: O(1) + Parameters: I: The type of index to use. @@ -289,7 +326,10 @@ struct LinkedList[ return data^ fn pop_if_present(mut self) -> Optional[ElementType]: - """Removes the head of the list and returns it, if it exists. + """ + Removes the head of the list and returns it, if it exists. + + Time Complexity: O(1) Returns: The head of the list, if it was present. @@ -301,9 +341,9 @@ struct LinkedList[ self._tail = elem[].prev self._size -= 1 if self._size == 0: - self._head = __type_of(self._head)() + self._head = Self._NodePointer() else: - self._tail[].next = Self.NodePointer() + self._tail[].next = Self._NodePointer() elem.free() return value^ @@ -314,6 +354,8 @@ struct LinkedList[ Remove the ith element of the list, counting from the tail if given a negative index. + Time Complexity: O(1) + Parameters: I: The type of index to use. @@ -349,7 +391,11 @@ struct LinkedList[ return Optional[ElementType](data^) fn clear(mut self): - """Removes all elements from the list.""" + """ + Removes all elements from the list. + + Time Complexity: O(n) in len(self) + """ var current = self._head while current: var old = current @@ -357,13 +403,15 @@ struct LinkedList[ old.destroy_pointee() old.free() - self._head = Self.NodePointer() - self._tail = Self.NodePointer() + self._head = Self._NodePointer() + self._tail = Self._NodePointer() self._size = 0 fn copy(self) -> Self: """Create a deep copy of the list. + Time Complexity: O(n) in len(self) + Returns: A new list containing copies of all elements. """ @@ -378,18 +426,25 @@ struct LinkedList[ """ Insert an element `elem` into the list at index `idx`. + Time Complexity: O(1) + + Raises: + When given an out of bounds index. + Args: - idx: The index to insert `elem` at. + idx: The index to insert `elem` at. `-len(self) <= idx <= len(self)`. elem: The item to insert into the list. """ var i = max(0, index(idx) if idx >= 0 else index(idx) + len(self)) if i == 0: - var node = Self.NodePointer.alloc(1) + var node = Self._NodePointer.alloc(1) if not node: - raise "OOM" + abort("Out of memory") node.init_pointee_move( - Node[ElementType](elem^, Self.NodePointer(), Self.NodePointer()) + Node[ElementType]( + elem^, Self._NodePointer(), Self._NodePointer() + ) ) if self._head: @@ -409,9 +464,9 @@ struct LinkedList[ var current = self._get_node_ptr(i) if current: var next = current[].next - var node = Self.NodePointer.alloc(1) + var node = Self._NodePointer.alloc(1) if not node: - raise "OOM" + abort("Out of memory") var data = UnsafePointer.address_of(node[].value) data[] = elem^ node[].next = next @@ -419,18 +474,19 @@ struct LinkedList[ if next: next[].prev = node current[].next = node - if node[].next == Self.NodePointer(): + if node[].next == Self._NodePointer(): self._tail = node - if node[].prev == Self.NodePointer(): + if node[].prev == Self._NodePointer(): self._head = node self._size += 1 else: - raise "index out of bounds" + raise String("Index {} out of bounds").format(idx) fn extend(mut self, owned other: Self): """ Extends the list with another. - O(1) time complexity. + + Time Complexity: O(1) Args: other: The list to append to this one. @@ -448,8 +504,8 @@ struct LinkedList[ self._tail = other._tail self._size = other._size - other._head = Self.NodePointer() - other._tail = Self.NodePointer() + other._head = Self._NodePointer() + other._tail = Self._NodePointer() fn count[ ElementType: EqualityComparableCollectionElement @@ -457,6 +513,8 @@ struct LinkedList[ """ Count the occurrences of `elem` in the list. + Time Complexity: O(n) in len(self) compares + Parameters: ElementType: The list element type, used to conditionally enable the function. @@ -482,6 +540,8 @@ struct LinkedList[ """ Checks if the list contains `value`. + Time Complexity: O(n) in len(self) compares + Parameters: ElementType: The list element type, used to conditionally enable the function. @@ -507,6 +567,8 @@ struct LinkedList[ """ Checks if the two lists are equal. + Time Complexity: O(n) in min(len(self), len(other)) compares + Parameters: ElementType: The list element type, used to conditionally enable the function. @@ -537,6 +599,8 @@ struct LinkedList[ """ Checks if the two lists are not equal. + Time Complexity: O(n) in min(len(self), len(other)) compares + Parameters: ElementType: The list element type, used to conditionally enable the function. @@ -549,11 +613,14 @@ struct LinkedList[ return not (self == other) fn _get_node_ptr(ref self, index: Int) -> UnsafePointer[Node[ElementType]]: - """Get a pointer to the node at the specified index. + """ + Get a pointer to the node at the specified index. This method optimizes traversal by starting from either the head or tail depending on which is closer to the target index. + Time Complexity: O(n) in len(self) + Args: index: The index of the node to get. @@ -576,7 +643,10 @@ struct LinkedList[ return curr fn __getitem__(ref self, index: Int) -> ref [self] ElementType: - """Get the element at the specified index. + """ + Get the element at the specified index. + + Time Complexity: O(n) in len(self) Args: index: The index of the element to get. @@ -588,7 +658,10 @@ struct LinkedList[ return self._get_node_ptr(index)[].value fn __setitem__(mut self, index: Int, owned value: ElementType): - """Set the element at the specified index. + """ + Set the element at the specified index. + + Time Complexity: O(n) in len(self) Args: index: The index of the element to set. @@ -600,6 +673,8 @@ struct LinkedList[ fn __len__(self) -> Int: """Get the number of elements in the list. + Time Complexity: O(1) + Returns: The number of elements in the list. """ @@ -608,6 +683,8 @@ struct LinkedList[ fn __bool__(self) -> Bool: """Check if the list is non-empty. + Time Complexity: O(1) + Returns: True if the list has elements, False otherwise. """ @@ -618,6 +695,8 @@ struct LinkedList[ ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Time Complexity: O(n) in len(self) + Parameters: ElementType: Used to conditionally enable this function when `ElementType` is `Writable`. @@ -634,6 +713,8 @@ struct LinkedList[ ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Time Complexity: O(n) in len(self) + Parameters: ElementType: Used to conditionally enable this function when `ElementType` is `Writable`. @@ -650,6 +731,8 @@ struct LinkedList[ ](self: LinkedList[ElementType], mut writer: W): """Write the list to the given writer. + Time Complexity: O(n) in len(self) + Parameters: W: The type of writer to write the list to. ElementType: Used to conditionally enable this function when diff --git a/stdlib/test/collections/test_linked_list.mojo b/stdlib/test/collections/test_linked_list.mojo index 6af2d0c4e4..c2d1cda34f 100644 --- a/stdlib/test/collections/test_linked_list.mojo +++ b/stdlib/test/collections/test_linked_list.mojo @@ -14,7 +14,13 @@ from collections import LinkedList, Optional from testing import assert_equal, assert_raises, assert_true, assert_false -from test_utils import CopyCounter, MoveCounter +from test_utils import ( + CopyCounter, + MoveCounter, + DtorCounter, + g_dtor_count, + CopyCountedStruct, +) def test_construction(): @@ -175,7 +181,7 @@ def test_list_pop(): # try popping from index 3 for 3 times for i in range(3, 6): - assert_equal[Int](i, list.pop(3)) + assert_equal(i, list.pop(3)) # list should have 3 elements now assert_equal(3, len(list)) @@ -303,8 +309,6 @@ def test_list_insert(): assert_equal(v1[1], 2) assert_equal(v1[2], 3) - print(v1.__str__()) - # # Test the list [1, 2, 3, 4, 5] created with negative and positive index # @@ -315,7 +319,6 @@ def test_list_insert(): v2.insert(len(v2), 5) v2.insert(-1, 4) v2.insert(-len(v2), 1) - print(v2.__str__()) assert_equal(len(v2), 5) assert_equal(v2[0], 1) @@ -358,9 +361,6 @@ def test_list_extend_non_trivial(): # - extend() for non-plain-old-data types # - extend() with mixed-length self and other lists # - extend() using optimal number of __moveinit__() calls - - # Preallocate with enough capacity to avoid reallocation making the - # move count checks below flaky. var v1 = LinkedList[MoveCounter[String]]() v1.append(MoveCounter[String]("Hello")) v1.append(MoveCounter[String]("World")) @@ -430,21 +430,6 @@ def test_list_explicit_copy(): assert_equal(l2[i], l2_copy[i]) -@value -struct CopyCountedStruct(CollectionElement): - var counter: CopyCounter - var value: String - - fn __init__(out self, *, other: Self): - self.counter = other.counter.copy() - self.value = other.value.copy() - - @implicit - fn __init__(out self, value: String): - self.counter = CopyCounter() - self.value = value - - def test_no_extra_copies_with_sugared_set_by_field(): var list = LinkedList[LinkedList[CopyCountedStruct]]() var child_list = LinkedList[CopyCountedStruct]() @@ -534,33 +519,6 @@ def test_indexing(): # ===-------------------------------------------------------------------===# # LinkedList dtor tests # ===-------------------------------------------------------------------===# -var g_dtor_count: Int = 0 - - -struct DtorCounter(CollectionElement, Writable): - # NOTE: payload is required because LinkedList does not support zero sized structs. - var payload: Int - - fn __init__(out self): - self.payload = 0 - - fn __init__(out self, *, other: Self): - self.payload = other.payload - - fn __copyinit__(out self, existing: Self, /): - self.payload = existing.payload - - fn __moveinit__(out self, owned existing: Self, /): - self.payload = existing.payload - existing.payload = 0 - - fn __del__(owned self): - g_dtor_count += 1 - - fn write_to[W: Writer](self, mut writer: W): - writer.write("DtorCounter(") - writer.write(String(g_dtor_count)) - writer.write(")") def inner_test_list_dtor(): diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index 2ed9a71306..9cec6810c2 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -16,7 +16,13 @@ from collections import List from sys.info import sizeof from memory import UnsafePointer, Span -from test_utils import CopyCounter, MoveCounter +from test_utils import ( + CopyCounter, + MoveCounter, + DtorCounter, + g_dtor_count, + CopyCountedStruct, +) from testing import assert_equal, assert_false, assert_raises, assert_true @@ -548,21 +554,6 @@ def test_list_explicit_copy(): assert_equal(l2[i], l2_copy[i]) -@value -struct CopyCountedStruct(CollectionElement): - var counter: CopyCounter - var value: String - - fn __init__(out self, *, other: Self): - self.counter = other.counter.copy() - self.value = other.value.copy() - - @implicit - fn __init__(out self, value: String): - self.counter = CopyCounter() - self.value = value - - def test_no_extra_copies_with_sugared_set_by_field(): var list = List[List[CopyCountedStruct]](capacity=1) var child_list = List[CopyCountedStruct](capacity=2) @@ -872,31 +863,6 @@ def test_indexing(): # ===-------------------------------------------------------------------===# # List dtor tests # ===-------------------------------------------------------------------===# -var g_dtor_count: Int = 0 - - -struct DtorCounter(CollectionElement): - # NOTE: payload is required because List does not support zero sized structs. - var payload: Int - - fn __init__(out self): - self.payload = 0 - - fn __init__(out self, *, other: Self): - self.payload = other.payload - - fn __copyinit__(out self, existing: Self, /): - self.payload = existing.payload - - fn copy(self) -> Self: - return self - - fn __moveinit__(out self, owned existing: Self, /): - self.payload = existing.payload - existing.payload = 0 - - fn __del__(owned self): - g_dtor_count += 1 def inner_test_list_dtor(): diff --git a/stdlib/test/test_utils/__init__.mojo b/stdlib/test/test_utils/__init__.mojo index 3dcecfc83b..704510e623 100644 --- a/stdlib/test/test_utils/__init__.mojo +++ b/stdlib/test/test_utils/__init__.mojo @@ -20,4 +20,7 @@ from .types import ( MoveOnly, ObservableDel, ValueDestructorRecorder, + DtorCounter, + g_dtor_count, + CopyCountedStruct, ) diff --git a/stdlib/test/test_utils/types.mojo b/stdlib/test/test_utils/types.mojo index 0610304a36..bd23c56e60 100644 --- a/stdlib/test/test_utils/types.mojo +++ b/stdlib/test/test_utils/types.mojo @@ -122,7 +122,6 @@ struct CopyCounter(CollectionElement, ExplicitlyCopyable, Writable): struct MoveCounter[T: CollectionElementNew]( CollectionElement, CollectionElementNew, - Writable, ): """Counts the number of moves performed on a value.""" @@ -154,17 +153,12 @@ struct MoveCounter[T: CollectionElementNew]( # TODO: This type should not be Copyable, but has to be to satisfy # CollectionElement at the moment. fn __copyinit__(out self, existing: Self): - # print("ERROR: _MoveCounter copy constructor called unexpectedly!") self.value = existing.value.copy() self.move_count = existing.move_count - fn copy(self) -> Self: - return self - - fn write_to[W: Writer](self, mut writer: W): - writer.write("MoveCounter(") - writer.write(String(self.move_count)) - writer.write(")") + fn copy(self, out existing: Self): + existing = Self(self.value.copy()) + existing.move_count = self.move_count # ===----------------------------------------------------------------------=== # @@ -202,3 +196,56 @@ struct ObservableDel(CollectionElement): fn __del__(owned self): self.target.init_pointee_move(True) + + +# ===----------------------------------------------------------------------=== # +# DtorCounter +# ===----------------------------------------------------------------------=== # + +var g_dtor_count: Int = 0 + + +struct DtorCounter(CollectionElement, Writable): + # NOTE: payload is required because LinkedList does not support zero sized structs. + var payload: Int + + fn __init__(out self): + self.payload = 0 + + fn __init__(out self, *, other: Self): + self.payload = other.payload + + fn __copyinit__(out self, existing: Self, /): + self.payload = existing.payload + + fn __moveinit__(out self, owned existing: Self, /): + self.payload = existing.payload + existing.payload = 0 + + fn __del__(owned self): + g_dtor_count += 1 + + fn write_to[W: Writer](self, mut writer: W): + writer.write("DtorCounter(") + writer.write(String(g_dtor_count)) + writer.write(")") + + +# ===----------------------------------------------------------------------=== # +# CopyCountedStruct +# ===----------------------------------------------------------------------=== # + + +@value +struct CopyCountedStruct(CollectionElement): + var counter: CopyCounter + var value: String + + fn __init__(out self, *, other: Self): + self.counter = other.counter.copy() + self.value = other.value.copy() + + @implicit + fn __init__(out self, value: String): + self.counter = CopyCounter() + self.value = value