diff --git a/stdlib/src/builtin/value.mojo b/stdlib/src/builtin/value.mojo index c3868f1a5a..4b991dc06b 100644 --- a/stdlib/src/builtin/value.mojo +++ b/stdlib/src/builtin/value.mojo @@ -291,3 +291,41 @@ trait BoolableKeyElement(Boolable, KeyElement): """ pass + + +trait EqualityComparableWritableCollectionElement( + WritableCollectionElement, EqualityComparable +): + """A trait that combines the CollectionElement, Writable and + EqualityComparable traits. + + This trait requires types to implement CollectionElement, Writable and + EqualityComparable interfaces, allowing them to be used in collections, + compared, and written to output. + """ + + pass + + +trait WritableCollectionElementNew(CollectionElementNew, Writable): + """A trait that combines the CollectionElement and Writable traits. + + This trait requires types to implement both CollectionElement and Writable + interfaces, allowing them to be used in collections and written to output. + """ + + pass + + +trait EqualityComparableWritableCollectionElementNew( + WritableCollectionElementNew, EqualityComparable +): + """A trait that combines the CollectionElement, Writable and + EqualityComparable traits. + + This trait requires types to implement CollectionElement, Writable and + EqualityComparable interfaces, allowing them to be used in collections, + compared, and written to output. + """ + + pass diff --git a/stdlib/src/collections/linked_list.mojo b/stdlib/src/collections/linked_list.mojo index cb2e362749..51780b4153 100644 --- a/stdlib/src/collections/linked_list.mojo +++ b/stdlib/src/collections/linked_list.mojo @@ -14,28 +14,33 @@ from memory import UnsafePointer from collections import Optional from collections._index_normalization import normalize_index +from os import abort @value -struct Node[ElementType: WritableCollectionElement]: +struct Node[ + ElementType: CollectionElement, +]: """A node in a linked list data structure. Parameters: ElementType: The type of element stored in the node. """ + alias _NodePointer = UnsafePointer[Self] + var value: ElementType """The value stored in this node.""" - var prev: UnsafePointer[Node[ElementType]] + var prev: Self._NodePointer """The previous node in the list.""" - var next: UnsafePointer[Node[ElementType]] + var next: Self._NodePointer """The next node in the list.""" fn __init__( out self, owned value: ElementType, - prev: Optional[UnsafePointer[Node[ElementType]]], - next: Optional[UnsafePointer[Node[ElementType]]], + prev: Optional[Self._NodePointer], + next: Optional[Self._NodePointer], ): """Initialize a new Node with the given value and optional prev/next pointers. @@ -46,22 +51,32 @@ struct Node[ElementType: WritableCollectionElement]: next: Optional pointer to the next node. """ self.value = value^ - self.prev = prev.value() if prev else __type_of(self.prev)() - self.next = next.value() if next else __type_of(self.next)() + self.prev = prev.value() if prev else Self._NodePointer() + self.next = next.value() if next else Self._NodePointer() - fn __str__(self) -> String: + fn __str__[ + ElementType: WritableCollectionElement + ](self: Node[ElementType]) -> String: """Convert this node's value to a string representation. + Parameters: + ElementType: Used to conditionally enable this function if + `ElementType` is `Writable`. + Returns: String representation of the node's value. """ - return String.write(self) + return String.write(self.value) @no_inline - fn write_to[W: Writer](self, mut writer: W): + fn write_to[ + ElementType: WritableCollectionElement, W: Writer + ](self: Node[ElementType], mut writer: W): """Write this node's value to the given writer. Parameters: + ElementType: Used to conditionally enable this function if + `ElementType` is `Writable`. W: The type of writer to write the value to. Args: @@ -70,7 +85,9 @@ struct Node[ElementType: WritableCollectionElement]: writer.write(self.value) -struct LinkedList[ElementType: WritableCollectionElement]: +struct LinkedList[ + ElementType: CollectionElement, +]: """A doubly-linked list implementation. A doubly-linked list is a data structure where each element points to both @@ -79,71 +96,105 @@ struct LinkedList[ElementType: WritableCollectionElement]: Parameters: ElementType: The type of elements stored in the list. Must implement - WritableCollectionElement. + CollectionElement. """ - var _head: UnsafePointer[Node[ElementType]] + alias _NodePointer = UnsafePointer[Node[ElementType]] + + var _head: Self._NodePointer """The first node in the list.""" - var _tail: UnsafePointer[Node[ElementType]] + var _tail: Self._NodePointer """The last node in the list.""" var _size: Int """The number of elements in the list.""" fn __init__(out self): - """Initialize an empty linked list.""" - self._head = __type_of(self._head)() - self._tail = __type_of(self._tail)() + """ + Initialize an empty linked list. + + Time Complexity: O(1) + """ + self._head = Self._NodePointer() + self._tail = Self._NodePointer() self._size = 0 fn __init__(mut self, owned *elements: ElementType): """Initialize a linked list with the given elements. + Time Complexity: O(n) in len(elements) + Args: elements: Variable number of elements to initialize the list with. """ self = Self(elements=elements^) fn __init__(out self, *, owned elements: VariadicListMem[ElementType, _]): - """Initialize a linked list with the given elements. + """ + Construct a list from a `VariadicListMem`. + + Time Complexity: O(n) in len(elements) Args: - elements: Variable number of elements to initialize the list with. + elements: The elements to add to the list. """ self = Self() - for elem in elements: - self.append(elem[]) + var length = len(elements) + + for i in range(length): + var src = UnsafePointer.address_of(elements[i]) + var node = Self._NodePointer.alloc(1) + if not node: + abort("Out of memory") + var dst = UnsafePointer.address_of(node[].value) + src.move_pointee_into(dst) + node[].next = Self._NodePointer() + node[].prev = self._tail + if self._tail: + self._tail[].next = node + self._tail = node + else: + self._head = node + self._tail = node # Do not destroy the elements when their backing storage goes away. __mlir_op.`lit.ownership.mark_destroyed`( __get_mvalue_as_litref(elements) ) + self._size = length + fn __copyinit__(mut self, read other: Self): """Initialize this list as a copy of another list. + Time Complexity: O(n) in len(elements) + Args: other: The list to copy from. """ - self._head = other._head - self._tail = other._tail - self._size = other._size + self = other.copy() fn __moveinit__(mut self, owned other: Self): """Initialize this list by moving elements from another list. + Time Complexity: O(1) + Args: other: The list to move elements from. """ self._head = other._head self._tail = other._tail self._size = other._size - other._head = __type_of(other._head)() - other._tail = __type_of(other._tail)() + other._head = Self._NodePointer() + other._tail = Self._NodePointer() other._size = 0 fn __del__(owned self): - """Clean up the list by freeing all nodes.""" + """ + Clean up the list by freeing all nodes. + + Time Complexity: O(n) in len(self) + """ var curr = self._head while curr: var next = curr[].next @@ -152,15 +203,22 @@ struct LinkedList[ElementType: WritableCollectionElement]: curr = next fn append(mut self, owned value: ElementType): - """Add an element to the end of the list. + """ + Add an element to the end of the list. + + Time Complexity: O(1) Args: value: The value to append. """ - var node = Node(value^, self._tail, None) - var addr = UnsafePointer[__type_of(node)].alloc(1) - addr.init_pointee_move(node) - if self: + var addr = Self._NodePointer.alloc(1) + if not addr: + abort("Out of memory") + var value_ptr = UnsafePointer.address_of(addr[].value) + value_ptr.init_pointee_move(value^) + addr[].prev = self._tail + addr[].next = Self._NodePointer() + if self._tail: self._tail[].next = addr else: self._head = addr @@ -168,13 +226,18 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._size += 1 fn prepend(mut self, owned value: ElementType): - """Add an element to the beginning of the list. + """ + Add an element to the beginning of the list. + + Time Complexity: O(1) Args: value: The value to prepend. """ var node = Node(value^, None, self._head) - var addr = UnsafePointer[__type_of(node)].alloc(1) + var addr = Self._NodePointer.alloc(1) + if not addr: + abort("Out of memory") addr.init_pointee_move(node) if self: self._head[].prev = addr @@ -184,8 +247,12 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._size += 1 fn reverse(mut self): - """Reverse the order of elements in the list.""" - var prev = __type_of(self._head)() + """ + Reverse the order of elements in the list. + + Time Complexity: O(n) in len(self) + """ + var prev = Self._NodePointer() var curr = self._head while curr: var next = curr[].next @@ -195,23 +262,156 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._tail = self._head self._head = prev - fn pop(mut self) -> ElementType: - """Remove and return the first element of the list. + fn pop(mut self) raises -> ElementType: + """Remove and return the last element of the list. + + Time Complexity: O(1) Returns: - The first element in the list. + The last element in the list. """ var elem = self._tail + if not elem: + raise "Pop on empty list." + var value = elem[].value self._tail = elem[].prev self._size -= 1 if self._size == 0: - self._head = __type_of(self._head)() + self._head = Self._NodePointer() + else: + self._tail[].next = Self._NodePointer() + elem.free() return value^ + fn pop[I: Indexer](mut self, owned i: I) raises -> ElementType: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Time Complexity: O(1) + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + Ownership of the indicated element. + """ + var current = self._get_node_ptr(Int(i)) + + if not current: + raise "Invalid index for pop" + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.value^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._size -= 1 + return data^ + + fn pop_if_present(mut self) -> Optional[ElementType]: + """ + Removes the head of the list and returns it, if it exists. + + Time Complexity: O(1) + + Returns: + The head of the list, if it was present. + """ + var elem = self._tail + if not elem: + return Optional[ElementType]() + var value = elem[].value + self._tail = elem[].prev + self._size -= 1 + if self._size == 0: + self._head = Self._NodePointer() + else: + self._tail[].next = Self._NodePointer() + elem.free() + return value^ + + fn pop_if_present[ + I: Indexer + ](mut self, owned i: I) -> Optional[ElementType]: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Time Complexity: O(1) + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + The element, if it was found. + """ + var current = self._get_node_ptr(Int(i)) + + if not current: + return Optional[ElementType]() + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.value^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._size -= 1 + return Optional[ElementType](data^) + + fn clear(mut self): + """ + Removes all elements from the list. + + Time Complexity: O(n) in len(self) + """ + var current = self._head + while current: + var old = current + current = current[].next + old.destroy_pointee() + old.free() + + self._head = Self._NodePointer() + self._tail = Self._NodePointer() + self._size = 0 + fn copy(self) -> Self: """Create a deep copy of the list. + Time Complexity: O(n) in len(self) + Returns: A new list containing copies of all elements. """ @@ -222,12 +422,205 @@ struct LinkedList[ElementType: WritableCollectionElement]: curr = curr[].next return new^ + fn insert(mut self, owned idx: Int, owned elem: ElementType) raises: + """ + Insert an element `elem` into the list at index `idx`. + + Time Complexity: O(1) + + Raises: + When given an out of bounds index. + + Args: + idx: The index to insert `elem` at. `-len(self) <= idx <= len(self)`. + elem: The item to insert into the list. + """ + var i = max(0, index(idx) if idx >= 0 else index(idx) + len(self)) + + if i == 0: + var node = Self._NodePointer.alloc(1) + if not node: + abort("Out of memory") + node.init_pointee_move( + Node[ElementType]( + elem^, Self._NodePointer(), Self._NodePointer() + ) + ) + + if self._head: + node[].next = self._head + self._head[].prev = node + + self._head = node + + if not self._tail: + self._tail = node + + self._size += 1 + return + + i -= 1 + + var current = self._get_node_ptr(i) + if current: + var next = current[].next + var node = Self._NodePointer.alloc(1) + if not node: + abort("Out of memory") + var data = UnsafePointer.address_of(node[].value) + data[] = elem^ + node[].next = next + node[].prev = current + if next: + next[].prev = node + current[].next = node + if node[].next == Self._NodePointer(): + self._tail = node + if node[].prev == Self._NodePointer(): + self._head = node + self._size += 1 + else: + raise String("Index {} out of bounds").format(idx) + + fn extend(mut self, owned other: Self): + """ + Extends the list with another. + + Time Complexity: O(1) + + Args: + other: The list to append to this one. + """ + if self._tail: + self._tail[].next = other._head + if other._head: + other._head[].prev = self._tail + if other._tail: + self._tail = other._tail + + self._size += other._size + else: + self._head = other._head + self._tail = other._tail + self._size = other._size + + other._head = Self._NodePointer() + other._tail = Self._NodePointer() + + fn count[ + ElementType: EqualityComparableCollectionElement + ](self: LinkedList[ElementType], read elem: ElementType) -> UInt: + """ + Count the occurrences of `elem` in the list. + + Time Complexity: O(n) in len(self) compares + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + elem: The element to search for. + + Returns: + The number of occurrences of `elem` in the list. + """ + var current = self._head + var count = 0 + while current: + if current[].value == elem: + count += 1 + + current = current[].next + + return count + + fn __contains__[ + ElementType: EqualityComparableCollectionElement, // + ](self: LinkedList[ElementType], value: ElementType) -> Bool: + """ + Checks if the list contains `value`. + + Time Complexity: O(n) in len(self) compares + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + value: The value to search for in the list. + + Returns: + Whether the list contains `value`. + """ + var current = self._head + while current: + if current[].value == value: + return True + current = current[].next + + return False + + fn __eq__[ + ElementType: EqualityComparableCollectionElement, // + ]( + read self: LinkedList[ElementType], read other: LinkedList[ElementType] + ) -> Bool: + """ + Checks if the two lists are equal. + + Time Complexity: O(n) in min(len(self), len(other)) compares + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are equal. + """ + if self._size != other._size: + return False + + var self_cursor = self._head + var other_cursor = other._head + + while self_cursor: + if self_cursor[].value != other_cursor[].value: + return False + + self_cursor = self_cursor[].next + other_cursor = other_cursor[].next + + return True + + fn __ne__[ + ElementType: EqualityComparableCollectionElement, // + ](self: LinkedList[ElementType], other: LinkedList[ElementType]) -> Bool: + """ + Checks if the two lists are not equal. + + Time Complexity: O(n) in min(len(self), len(other)) compares + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are not equal. + """ + return not (self == other) + fn _get_node_ptr(ref self, index: Int) -> UnsafePointer[Node[ElementType]]: - """Get a pointer to the node at the specified index. + """ + Get a pointer to the node at the specified index. This method optimizes traversal by starting from either the head or tail depending on which is closer to the target index. + Time Complexity: O(n) in len(self) + Args: index: The index of the node to get. @@ -250,7 +643,10 @@ struct LinkedList[ElementType: WritableCollectionElement]: return curr fn __getitem__(ref self, index: Int) -> ref [self] ElementType: - """Get the element at the specified index. + """ + Get the element at the specified index. + + Time Complexity: O(n) in len(self) Args: index: The index of the element to get. @@ -262,7 +658,10 @@ struct LinkedList[ElementType: WritableCollectionElement]: return self._get_node_ptr(index)[].value fn __setitem__(mut self, index: Int, owned value: ElementType): - """Set the element at the specified index. + """ + Set the element at the specified index. + + Time Complexity: O(n) in len(self) Args: index: The index of the element to set. @@ -274,6 +673,8 @@ struct LinkedList[ElementType: WritableCollectionElement]: fn __len__(self) -> Int: """Get the number of elements in the list. + Time Complexity: O(1) + Returns: The number of elements in the list. """ @@ -282,22 +683,42 @@ struct LinkedList[ElementType: WritableCollectionElement]: fn __bool__(self) -> Bool: """Check if the list is non-empty. + Time Complexity: O(1) + Returns: True if the list has elements, False otherwise. """ return len(self) != 0 - fn __str__(self) -> String: + fn __str__[ + ElementType: WritableCollectionElement + ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Time Complexity: O(n) in len(self) + + Parameters: + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. + Returns: String representation of the list. """ - return String.write(self) + var writer = String() + self._write(writer) + return writer - fn __repr__(self) -> String: + fn __repr__[ + ElementType: WritableCollectionElement + ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Time Complexity: O(n) in len(self) + + Parameters: + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. + Returns: String representation of the list. """ @@ -305,11 +726,17 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._write(writer, prefix="LinkedList(", suffix=")") return writer - fn write_to[W: Writer](self, mut writer: W): + fn write_to[ + W: Writer, ElementType: WritableCollectionElement + ](self: LinkedList[ElementType], mut writer: W): """Write the list to the given writer. + Time Complexity: O(n) in len(self) + Parameters: W: The type of writer to write the list to. + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. Args: writer: The writer to write the list to. @@ -318,8 +745,14 @@ struct LinkedList[ElementType: WritableCollectionElement]: @no_inline fn _write[ - W: Writer - ](self, mut writer: W, *, prefix: String = "[", suffix: String = "]"): + W: Writer, ElementType: WritableCollectionElement + ]( + self: LinkedList[ElementType], + mut writer: W, + *, + prefix: String = "[", + suffix: String = "]", + ): if not self: return writer.write(prefix, suffix) @@ -328,6 +761,6 @@ struct LinkedList[ElementType: WritableCollectionElement]: for i in range(len(self)): if i: writer.write(", ") - writer.write(curr[]) + writer.write(curr[].value) curr = curr[].next writer.write(suffix) diff --git a/stdlib/test/collections/test_linked_list.mojo b/stdlib/test/collections/test_linked_list.mojo index a5f3467736..c2d1cda34f 100644 --- a/stdlib/test/collections/test_linked_list.mojo +++ b/stdlib/test/collections/test_linked_list.mojo @@ -12,8 +12,15 @@ # ===----------------------------------------------------------------------=== # # RUN: %mojo %s -from collections import LinkedList -from testing import assert_equal +from collections import LinkedList, Optional +from testing import assert_equal, assert_raises, assert_true, assert_false +from test_utils import ( + CopyCounter, + MoveCounter, + DtorCounter, + g_dtor_count, + CopyCountedStruct, +) def test_construction(): @@ -101,12 +108,439 @@ def test_setitem(): def test_str(): var l1 = LinkedList[Int](1, 2, 3) - assert_equal(String(l1), "[1, 2, 3]") + assert_equal(l1.__str__(), "[1, 2, 3]") def test_repr(): var l1 = LinkedList[Int](1, 2, 3) - assert_equal(repr(l1), "LinkedList(1, 2, 3)") + assert_equal(l1.__repr__(), "LinkedList(1, 2, 3)") + + +def test_pop_on_empty_list(): + with assert_raises(): + var ll = LinkedList[Int]() + _ = ll.pop() + + +def test_optional_pop_on_empty_linked_list(): + var ll = LinkedList[Int]() + var result = ll.pop_if_present() + assert_false(Bool(result)) + + +def test_list(): + var list = LinkedList[Int]() + + for i in range(5): + list.append(i) + + assert_equal(5, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + assert_equal(3, list[3]) + assert_equal(4, list[4]) + + assert_equal(0, list[-5]) + assert_equal(3, list[-2]) + assert_equal(4, list[-1]) + + list[2] = -2 + assert_equal(-2, list[2]) + + list[-5] = 5 + assert_equal(5, list[-5]) + list[-2] = 3 + assert_equal(3, list[-2]) + list[-1] = 7 + assert_equal(7, list[-1]) + + +def test_list_clear(): + var list = LinkedList[Int](1, 2, 3) + assert_equal(len(list), 3) + list.clear() + + assert_equal(len(list), 0) + + +def test_list_to_bool_conversion(): + assert_false(LinkedList[String]()) + assert_true(LinkedList[String]("a")) + assert_true(LinkedList[String]("", "a")) + assert_true(LinkedList[String]("")) + + +def test_list_pop(): + var list = LinkedList[Int]() + # Test pop with index + for i in range(6): + list.append(i) + + assert_equal(6, len(list)) + + # try popping from index 3 for 3 times + for i in range(3, 6): + assert_equal(i, list.pop(3)) + + # list should have 3 elements now + assert_equal(3, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + + # Test pop with negative index + for i in range(0, 2): + var popped: Int = list.pop(-len(list)) + assert_equal(i, popped) + + # test default index as well + assert_equal(2, list.pop()) + list.append(2) + assert_equal(2, list.pop()) + + # list should be empty now + assert_equal(0, len(list)) + + +def test_list_variadic_constructor(): + var l = LinkedList[Int](2, 4, 6) + assert_equal(3, len(l)) + assert_equal(2, l[0]) + assert_equal(4, l[1]) + assert_equal(6, l[2]) + + l.append(8) + assert_equal(4, len(l)) + assert_equal(8, l[3]) + + # + # Test variadic construct copying behavior + # + + var l2 = LinkedList[CopyCounter]( + CopyCounter(), CopyCounter(), CopyCounter() + ) + + assert_equal(len(l2), 3) + assert_equal(l2[0].copy_count, 0) + assert_equal(l2[1].copy_count, 0) + assert_equal(l2[2].copy_count, 0) + + +def test_list_reverse(): + # + # Test reversing the list [] + # + + var vec = LinkedList[Int]() + + assert_equal(len(vec), 0) + + vec.reverse() + + assert_equal(len(vec), 0) + + # + # Test reversing the list [123] + # + + vec = LinkedList[Int]() + + vec.append(123) + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + vec.reverse() + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + # + # Test reversing the list ["one", "two", "three"] + # + + var vec2 = LinkedList[String]("one", "two", "three") + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "one") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "three") + + vec2.reverse() + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "three") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "one") + + # + # Test reversing the list [5, 10] + # + + vec = LinkedList[Int]() + vec.append(5) + vec.append(10) + + assert_equal(len(vec), 2) + assert_equal(vec[0], 5) + assert_equal(vec[1], 10) + + vec.reverse() + + assert_equal(len(vec), 2) + assert_equal(vec[0], 10) + assert_equal(vec[1], 5) + + +def test_list_insert(): + # + # Test the list [1, 2, 3] created with insert + # + + var v1 = LinkedList[Int]() + v1.insert(len(v1), 1) + v1.insert(len(v1), 3) + v1.insert(1, 2) + + assert_equal(len(v1), 3) + assert_equal(v1[0], 1) + assert_equal(v1[1], 2) + assert_equal(v1[2], 3) + + # + # Test the list [1, 2, 3, 4, 5] created with negative and positive index + # + + var v2 = LinkedList[Int]() + v2.insert(-1729, 2) + v2.insert(len(v2), 3) + v2.insert(len(v2), 5) + v2.insert(-1, 4) + v2.insert(-len(v2), 1) + + assert_equal(len(v2), 5) + assert_equal(v2[0], 1) + assert_equal(v2[1], 2) + assert_equal(v2[2], 3) + assert_equal(v2[3], 4) + assert_equal(v2[4], 5) + + # + # Test the list [1, 2, 3, 4] created with negative index + # + + var v3 = LinkedList[Int]() + v3.insert(-11, 4) + v3.insert(-13, 3) + v3.insert(-17, 2) + v3.insert(-19, 1) + + assert_equal(len(v3), 4) + assert_equal(v3[0], 1) + assert_equal(v3[1], 2) + assert_equal(v3[2], 3) + assert_equal(v3[3], 4) + + # + # Test the list [1, 2, 3, 4, 5, 6, 7, 8] created with insert + # + + var v4 = LinkedList[Int]() + for i in range(4): + v4.insert(0, 4 - i) + v4.insert(len(v4), 4 + i + 1) + + for i in range(len(v4)): + assert_equal(v4[i], i + 1) + + +def test_list_extend_non_trivial(): + # Tests three things: + # - extend() for non-plain-old-data types + # - extend() with mixed-length self and other lists + # - extend() using optimal number of __moveinit__() calls + var v1 = LinkedList[MoveCounter[String]]() + v1.append(MoveCounter[String]("Hello")) + v1.append(MoveCounter[String]("World")) + + var v2 = LinkedList[MoveCounter[String]]() + v2.append(MoveCounter[String]("Foo")) + v2.append(MoveCounter[String]("Bar")) + v2.append(MoveCounter[String]("Baz")) + + v1.extend(v2^) + + assert_equal(len(v1), 5) + assert_equal(v1[0].value, "Hello") + assert_equal(v1[1].value, "World") + assert_equal(v1[2].value, "Foo") + assert_equal(v1[3].value, "Bar") + assert_equal(v1[4].value, "Baz") + + assert_equal(v1[0].move_count, 1) + assert_equal(v1[1].move_count, 1) + assert_equal(v1[2].move_count, 1) + assert_equal(v1[3].move_count, 1) + assert_equal(v1[4].move_count, 1) + + +def test_2d_dynamic_list(): + var list = LinkedList[LinkedList[Int]]() + + for i in range(2): + var v = LinkedList[Int]() + for j in range(3): + v.append(i + j) + list.append(v) + + assert_equal(0, list[0][0]) + assert_equal(1, list[0][1]) + assert_equal(2, list[0][2]) + assert_equal(1, list[1][0]) + assert_equal(2, list[1][1]) + assert_equal(3, list[1][2]) + + assert_equal(2, len(list)) + + assert_equal(3, len(list[0])) + + list[0].clear() + assert_equal(0, len(list[0])) + + list.clear() + assert_equal(0, len(list)) + + +def test_list_explicit_copy(): + var list = LinkedList[CopyCounter]() + list.append(CopyCounter()) + var list_copy = list.copy() + assert_equal(0, list[0].copy_count) + assert_equal(1, list_copy[0].copy_count) + + var l2 = LinkedList[Int]() + for i in range(10): + l2.append(i) + + var l2_copy = l2.copy() + assert_equal(len(l2), len(l2_copy)) + for i in range(len(l2)): + assert_equal(l2[i], l2_copy[i]) + + +def test_no_extra_copies_with_sugared_set_by_field(): + var list = LinkedList[LinkedList[CopyCountedStruct]]() + var child_list = LinkedList[CopyCountedStruct]() + child_list.append(CopyCountedStruct("Hello")) + child_list.append(CopyCountedStruct("World")) + + # No copies here. Constructing with LinkedList[CopyCountedStruct](CopyCountedStruct("Hello")) is a copy. + assert_equal(0, child_list[0].counter.copy_count) + assert_equal(0, child_list[1].counter.copy_count) + + list.append(child_list^) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + # list[0][1] makes a copy for reasons I cannot determine + list.__getitem__(0).__getitem__(1).value = "Mojo" + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + assert_equal("Mojo", list[0][1].value) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + +def test_list_boolable(): + assert_true(LinkedList[Int](1)) + assert_false(LinkedList[Int]()) + + +def test_list_count(): + var list = LinkedList[Int](1, 2, 3, 2, 5, 6, 7, 8, 9, 10) + assert_equal(1, list.count(1)) + assert_equal(2, list.count(2)) + assert_equal(0, list.count(4)) + + var list2 = LinkedList[Int]() + assert_equal(0, list2.count(1)) + + +def test_list_contains(): + var x = LinkedList[Int](1, 2, 3) + assert_false(0 in x) + assert_true(1 in x) + assert_false(4 in x) + + # TODO: implement LinkedList.__eq__ for Self[ComparableCollectionElement] + # var y = LinkedList[LinkedList[Int]]() + # y.append(LinkedList(1,2)) + # assert_equal(LinkedList(1,2) in y,True) + # assert_equal(LinkedList(0,1) in y,False) + + +def test_list_eq_ne(): + var l1 = LinkedList[Int](1, 2, 3) + var l2 = LinkedList[Int](1, 2, 3) + assert_true(l1 == l2) + assert_false(l1 != l2) + + var l3 = LinkedList[Int](1, 2, 3, 4) + assert_false(l1 == l3) + assert_true(l1 != l3) + + var l4 = LinkedList[Int]() + var l5 = LinkedList[Int]() + assert_true(l4 == l5) + assert_true(l1 != l4) + + var l6 = LinkedList[String]("a", "b", "c") + var l7 = LinkedList[String]("a", "b", "c") + var l8 = LinkedList[String]("a", "b") + assert_true(l6 == l7) + assert_false(l6 != l7) + assert_false(l6 == l8) + + +def test_indexing(): + var l = LinkedList[Int](1, 2, 3) + assert_equal(l[Int(1)], 2) + assert_equal(l[False], 1) + assert_equal(l[True], 2) + assert_equal(l[2], 3) + + +# ===-------------------------------------------------------------------===# +# LinkedList dtor tests +# ===-------------------------------------------------------------------===# + + +def inner_test_list_dtor(): + # explicitly reset global counter + g_dtor_count = 0 + + var l = LinkedList[DtorCounter]() + assert_equal(g_dtor_count, 0) + + l.append(DtorCounter()) + assert_equal(g_dtor_count, 0) + + l^.__del__() + assert_equal(g_dtor_count, 1) + + +def test_list_dtor(): + # call another function to force the destruction of the list + inner_test_list_dtor() + + # verify we still only ran the destructor once + assert_equal(g_dtor_count, 1) def main(): @@ -120,3 +554,22 @@ def main(): test_setitem() test_str() test_repr() + test_pop_on_empty_list() + test_optional_pop_on_empty_linked_list() + test_list() + test_list_clear() + test_list_to_bool_conversion() + test_list_pop() + test_list_variadic_constructor() + test_list_reverse() + test_list_extend_non_trivial() + test_list_explicit_copy() + test_no_extra_copies_with_sugared_set_by_field() + test_2d_dynamic_list() + test_list_boolable() + test_list_count() + test_list_contains() + test_indexing() + test_list_dtor() + test_list_insert() + test_list_eq_ne() diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index 2ed9a71306..9cec6810c2 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -16,7 +16,13 @@ from collections import List from sys.info import sizeof from memory import UnsafePointer, Span -from test_utils import CopyCounter, MoveCounter +from test_utils import ( + CopyCounter, + MoveCounter, + DtorCounter, + g_dtor_count, + CopyCountedStruct, +) from testing import assert_equal, assert_false, assert_raises, assert_true @@ -548,21 +554,6 @@ def test_list_explicit_copy(): assert_equal(l2[i], l2_copy[i]) -@value -struct CopyCountedStruct(CollectionElement): - var counter: CopyCounter - var value: String - - fn __init__(out self, *, other: Self): - self.counter = other.counter.copy() - self.value = other.value.copy() - - @implicit - fn __init__(out self, value: String): - self.counter = CopyCounter() - self.value = value - - def test_no_extra_copies_with_sugared_set_by_field(): var list = List[List[CopyCountedStruct]](capacity=1) var child_list = List[CopyCountedStruct](capacity=2) @@ -872,31 +863,6 @@ def test_indexing(): # ===-------------------------------------------------------------------===# # List dtor tests # ===-------------------------------------------------------------------===# -var g_dtor_count: Int = 0 - - -struct DtorCounter(CollectionElement): - # NOTE: payload is required because List does not support zero sized structs. - var payload: Int - - fn __init__(out self): - self.payload = 0 - - fn __init__(out self, *, other: Self): - self.payload = other.payload - - fn __copyinit__(out self, existing: Self, /): - self.payload = existing.payload - - fn copy(self) -> Self: - return self - - fn __moveinit__(out self, owned existing: Self, /): - self.payload = existing.payload - existing.payload = 0 - - fn __del__(owned self): - g_dtor_count += 1 def inner_test_list_dtor(): diff --git a/stdlib/test/test_utils/__init__.mojo b/stdlib/test/test_utils/__init__.mojo index 3dcecfc83b..704510e623 100644 --- a/stdlib/test/test_utils/__init__.mojo +++ b/stdlib/test/test_utils/__init__.mojo @@ -20,4 +20,7 @@ from .types import ( MoveOnly, ObservableDel, ValueDestructorRecorder, + DtorCounter, + g_dtor_count, + CopyCountedStruct, ) diff --git a/stdlib/test/test_utils/types.mojo b/stdlib/test/test_utils/types.mojo index 01870acebc..bd23c56e60 100644 --- a/stdlib/test/test_utils/types.mojo +++ b/stdlib/test/test_utils/types.mojo @@ -88,7 +88,7 @@ struct ImplicitCopyOnly(Copyable): # ===----------------------------------------------------------------------=== # -struct CopyCounter(CollectionElement, ExplicitlyCopyable): +struct CopyCounter(CollectionElement, ExplicitlyCopyable, Writable): """Counts the number of copies performed on a value.""" var copy_count: Int @@ -108,6 +108,11 @@ struct CopyCounter(CollectionElement, ExplicitlyCopyable): fn copy(self) -> Self: return self + fn write_to[W: Writer](self, mut writer: W): + writer.write("CopyCounter(") + writer.write(String(self.copy_count)) + writer.write(")") + # ===----------------------------------------------------------------------=== # # MoveCounter @@ -148,12 +153,12 @@ struct MoveCounter[T: CollectionElementNew]( # TODO: This type should not be Copyable, but has to be to satisfy # CollectionElement at the moment. fn __copyinit__(out self, existing: Self): - # print("ERROR: _MoveCounter copy constructor called unexpectedly!") self.value = existing.value.copy() self.move_count = existing.move_count - fn copy(self) -> Self: - return self + fn copy(self, out existing: Self): + existing = Self(self.value.copy()) + existing.move_count = self.move_count # ===----------------------------------------------------------------------=== # @@ -191,3 +196,56 @@ struct ObservableDel(CollectionElement): fn __del__(owned self): self.target.init_pointee_move(True) + + +# ===----------------------------------------------------------------------=== # +# DtorCounter +# ===----------------------------------------------------------------------=== # + +var g_dtor_count: Int = 0 + + +struct DtorCounter(CollectionElement, Writable): + # NOTE: payload is required because LinkedList does not support zero sized structs. + var payload: Int + + fn __init__(out self): + self.payload = 0 + + fn __init__(out self, *, other: Self): + self.payload = other.payload + + fn __copyinit__(out self, existing: Self, /): + self.payload = existing.payload + + fn __moveinit__(out self, owned existing: Self, /): + self.payload = existing.payload + existing.payload = 0 + + fn __del__(owned self): + g_dtor_count += 1 + + fn write_to[W: Writer](self, mut writer: W): + writer.write("DtorCounter(") + writer.write(String(g_dtor_count)) + writer.write(")") + + +# ===----------------------------------------------------------------------=== # +# CopyCountedStruct +# ===----------------------------------------------------------------------=== # + + +@value +struct CopyCountedStruct(CollectionElement): + var counter: CopyCounter + var value: String + + fn __init__(out self, *, other: Self): + self.counter = other.counter.copy() + self.value = other.value.copy() + + @implicit + fn __init__(out self, value: String): + self.counter = CopyCounter() + self.value = value