From 3d2471f37235a40d870d78e788541062e29b27c3 Mon Sep 17 00:00:00 2001 From: sakno Date: Sun, 14 Jul 2024 06:02:30 +0300 Subject: [PATCH] Fixed #247 --- .../Collections/Generic/CollectionTests.cs | 69 ++++++++++++++++++ .../Collections/Generic/AsyncEnumerable.cs | 33 ++++----- src/DotNext/Collections/Generic/Collection.cs | 73 +++++++++++++++++-- src/DotNext/Collections/Generic/Enumerator.cs | 8 +- src/DotNext/Span.cs | 14 +++- 5 files changed, 165 insertions(+), 32 deletions(-) diff --git a/src/DotNext.Tests/Collections/Generic/CollectionTests.cs b/src/DotNext.Tests/Collections/Generic/CollectionTests.cs index f4a08f3d5..1e56a7884 100644 --- a/src/DotNext.Tests/Collections/Generic/CollectionTests.cs +++ b/src/DotNext.Tests/Collections/Generic/CollectionTests.cs @@ -1,4 +1,6 @@ using System.Collections.Concurrent; +using System.Collections.Immutable; +using System.Runtime.InteropServices; namespace DotNext.Collections.Generic; @@ -79,6 +81,20 @@ public static void ForEachTest() array2.ForEach(counter.Accept); Equal(5, counter.value); } + + [Fact] + public static async Task ForEachTestAsync() + { + IList list = new List { 1, 10, 20 }; + var counter = new Counter(); + await list.ForEachAsync(counter.AcceptAsync); + Equal(3, counter.value); + counter.value = 0; + + var array2 = new int[] { 1, 2, 10, 11, 15 }; + await array2.ForEachAsync(counter.AcceptAsync); + Equal(5, counter.value); + } [Fact] public static void ElementAtIndex() @@ -275,4 +291,57 @@ public static void CopyString() using var copy = "abcd".Copy(); Equal("abcd", copy.Memory.ToString()); } + + [Fact] + public static void FirstOrNone() + { + Equal(5, new[] { 5, 6 }.FirstOrNone()); + Equal(5, new List { 5, 6 }.FirstOrNone()); + Equal(5, new LinkedList([5, 6]).FirstOrNone()); + Equal('5', "56".FirstOrNone()); + Equal(5, ImmutableArray.Create([5, 6]).FirstOrNone()); + Equal(5, GetValues().FirstOrNone()); + + Equal(Optional.None, Array.Empty().FirstOrNone()); + Equal(Optional.None, new List().FirstOrNone()); + Equal(Optional.None, new LinkedList().FirstOrNone()); + Equal(Optional.None, string.Empty.FirstOrNone()); + Equal(Optional.None, ImmutableArray.Empty.FirstOrNone()); + Equal(Optional.None, EmptyEnumerable().FirstOrNone()); + + static IEnumerable GetValues() + { + yield return 5; + yield return 6; + } + } + + [Fact] + public static void LastOrNone() + { + Equal(6, new[] { 5, 6 }.LastOrNone()); + Equal(6, new List { 5, 6 }.LastOrNone()); + Equal(6, new LinkedList([5, 6]).LastOrNone()); + Equal('6', "56".LastOrNone()); + Equal(6, ImmutableArray.Create([5, 6]).LastOrNone()); + Equal(6, GetValues().LastOrNone()); + + Equal(Optional.None, Array.Empty().LastOrNone()); + Equal(Optional.None, new List().LastOrNone()); + Equal(Optional.None, new LinkedList().LastOrNone()); + Equal(Optional.None, string.Empty.LastOrNone()); + Equal(Optional.None, ImmutableArray.Empty.LastOrNone()); + Equal(Optional.None, EmptyEnumerable().LastOrNone()); + + static IEnumerable GetValues() + { + yield return 5; + yield return 6; + } + } + + static IEnumerable EmptyEnumerable() + { + yield break; + } } \ No newline at end of file diff --git a/src/DotNext/Collections/Generic/AsyncEnumerable.cs b/src/DotNext/Collections/Generic/AsyncEnumerable.cs index 355789a4d..5eb5bcdf2 100644 --- a/src/DotNext/Collections/Generic/AsyncEnumerable.cs +++ b/src/DotNext/Collections/Generic/AsyncEnumerable.cs @@ -8,7 +8,7 @@ namespace DotNext.Collections.Generic; public static partial class AsyncEnumerable { /// - /// Applies specified action to each collection element asynchronously. + /// Applies specified action to each element of the collection asynchronously. /// /// Type of elements in the collection. /// A collection to enumerate. Cannot be . @@ -23,7 +23,7 @@ public static async ValueTask ForEachAsync(this IAsyncEnumerable collectio } /// - /// Applies specified action to each collection element asynchronously. + /// Applies the specified action to each element of the collection asynchronously. /// /// Type of elements in the collection. /// A collection to enumerate. Cannot be . @@ -38,8 +38,8 @@ public static async ValueTask ForEachAsync(this IAsyncEnumerable collectio } /// - /// Obtains first value type in the sequence; or - /// if sequence is empty. + /// Obtains the first value of a sequence; or + /// if the sequence is empty. /// /// Type of elements in the sequence. /// A sequence to check. Cannot be . @@ -55,8 +55,8 @@ public static async ValueTask ForEachAsync(this IAsyncEnumerable collectio } /// - /// Obtains the last value type in the sequence; or - /// if sequence is empty. + /// Obtains the last value of a sequence; or + /// if the sequence is empty. /// /// Type of elements in the sequence. /// A sequence to check. Cannot be . @@ -74,8 +74,8 @@ public static async ValueTask ForEachAsync(this IAsyncEnumerable collectio } /// - /// Obtains first element in the sequence; or - /// if sequence is empty. + /// Obtains the first element of a sequence; or + /// if the sequence is empty. /// /// Type of elements in the sequence. /// A sequence to check. Cannot be . @@ -90,8 +90,8 @@ public static async ValueTask> FirstOrNoneAsync(this IAsyncEnumer } /// - /// Obtains the last element in the sequence; or - /// if sequence is empty. + /// Obtains the last element of a sequence; or + /// if the sequence is empty. /// /// Type of elements in the sequence. /// A sequence to check. Cannot be . @@ -118,6 +118,8 @@ public static async ValueTask> LastOrNoneAsync(this IAsyncEnumera /// The operation has been canceled. public static async ValueTask> FirstOrNoneAsync(this IAsyncEnumerable seq, Predicate filter, CancellationToken token = default) { + ArgumentNullException.ThrowIfNull(filter); + await foreach (var item in seq.WithCancellation(token).ConfigureAwait(false)) { if (filter(item)) @@ -137,11 +139,10 @@ public static async ValueTask> FirstOrNoneAsync(this IAsyncEnumer /// The operation has been canceled. public static async ValueTask SkipAsync(this IAsyncEnumerator enumerator, int count) { - while (count > 0) + for (; count > 0; count--) { if (!await enumerator.MoveNextAsync().ConfigureAwait(false)) return false; - count--; } return true; @@ -161,11 +162,9 @@ public static async ValueTask> ElementAtAsync(this IAsyncEnumerab var enumerator = collection.GetAsyncEnumerator(token); await using (enumerator.ConfigureAwait(false)) { - await enumerator.SkipAsync(index).ConfigureAwait(false); - - return await enumerator.MoveNextAsync().ConfigureAwait(false) ? - enumerator.Current : - Optional.None; + return await enumerator.SkipAsync(index).ConfigureAwait(false) && await enumerator.MoveNextAsync().ConfigureAwait(false) + ? enumerator.Current + : Optional.None; } } diff --git a/src/DotNext/Collections/Generic/Collection.cs b/src/DotNext/Collections/Generic/Collection.cs index 3a2747efb..8fd65623f 100644 --- a/src/DotNext/Collections/Generic/Collection.cs +++ b/src/DotNext/Collections/Generic/Collection.cs @@ -165,6 +165,67 @@ public static async ValueTask ForEachAsync(this IEnumerable collection, Fu await action.Invoke(item, token).ConfigureAwait(false); } + /// + /// Obtains the first element of a sequence; or + /// if the sequence is empty. + /// + /// The collection to return the first element of. + /// The type of the element of a collection. + /// The first element; or + public static Optional FirstOrNone(this IEnumerable collection) + { + return collection switch + { + null => throw new ArgumentNullException(nameof(collection)), + List list => Span.FirstOrNone(CollectionsMarshal.AsSpan(list)), + T[] array => Span.FirstOrNone(array), + string str => Unsafe.BitCast, Optional>(Span.FirstOrNone(str)), + LinkedList list => list.First is { } first ? first.Value : Optional.None, + IList list => list.Count > 0 ? list[0] : Optional.None, + IReadOnlyList readOnlyList => readOnlyList.Count > 0 ? readOnlyList[0] : Optional.None, + _ => FirstOrNoneSlow(collection), + }; + + static Optional FirstOrNoneSlow(IEnumerable collection) + { + using var enumerator = collection.GetEnumerator(); + return enumerator.MoveNext() ? enumerator.Current : Optional.None; + } + } + + /// + /// Obtains the last element of a sequence; or + /// if the sequence is empty. + /// + /// The collection to return the first element of. + /// The type of the element of a collection. + /// The first element; or + public static Optional LastOrNone(this IEnumerable collection) + { + return collection switch + { + null => throw new ArgumentNullException(nameof(collection)), + List list => Span.LastOrNone(CollectionsMarshal.AsSpan(list)), + T[] array => Span.LastOrNone(array), + string str => Unsafe.BitCast, Optional>(Span.LastOrNone(str)), + LinkedList list => list.Last is { } last ? last.Value : Optional.None, + IList list => list.Count > 0 ? list[^1] : Optional.None, + IReadOnlyList readOnlyList => readOnlyList.Count > 0 ? readOnlyList[^1] : Optional.None, + _ => LastOrNoneSlow(collection), + }; + + static Optional LastOrNoneSlow(IEnumerable collection) + { + var result = Optional.None(); + foreach (var item in collection) + { + result = item; + } + + return result; + } + } + /// /// Obtains element at the specified index in the sequence. /// @@ -181,6 +242,7 @@ public static bool ElementAt(this IEnumerable collection, int index, [Mayb { return collection switch { + null => throw new ArgumentNullException(nameof(collection)), List list => Span.ElementAt(CollectionsMarshal.AsSpan(list), index, out element), T[] array => Span.ElementAt(array, index, out element), LinkedList list => NodeValueAt(list, index, out element), @@ -211,14 +273,15 @@ static bool NodeValueAt(LinkedList list, int matchIndex, [MaybeNullWhen(false static bool ElementAtSlow(IEnumerable collection, int index, [MaybeNullWhen(false)] out T element) { using var enumerator = collection.GetEnumerator(); - enumerator.Skip(index); - if (enumerator.MoveNext()) + + // enumerator.Skip(index + 1) may overflow, replace it with two calls + if (enumerator.Skip(index) && enumerator.MoveNext()) { element = enumerator.Current; return true; } - element = default!; + element = default; return false; } @@ -230,7 +293,7 @@ static bool ListElementAt(IList list, int index, [MaybeNullWhen(false)] out T return true; } - element = default!; + element = default; return false; } @@ -242,7 +305,7 @@ static bool ReadOnlyListElementAt(IReadOnlyList list, int index, [MaybeNullWh return true; } - element = default!; + element = default; return false; } } diff --git a/src/DotNext/Collections/Generic/Enumerator.cs b/src/DotNext/Collections/Generic/Enumerator.cs index c451ddfff..abadeb7d8 100644 --- a/src/DotNext/Collections/Generic/Enumerator.cs +++ b/src/DotNext/Collections/Generic/Enumerator.cs @@ -19,12 +19,10 @@ public static partial class Enumerator /// , if current element is available; otherwise, . public static bool Skip(this IEnumerator enumerator, int count) { - while (count > 0) + for (; count > 0; count--) { if (!enumerator.MoveNext()) return false; - - count--; } return true; @@ -41,12 +39,10 @@ public static bool Skip(this IEnumerator enumerator, int count) public static bool Skip(this ref TEnumerator enumerator, int count) where TEnumerator : struct, IEnumerator { - while (count > 0) + for (; count > 0; count--) { if (!enumerator.MoveNext()) return false; - - count--; } return true; diff --git a/src/DotNext/Span.cs b/src/DotNext/Span.cs index b203057b6..9644d7193 100644 --- a/src/DotNext/Span.cs +++ b/src/DotNext/Span.cs @@ -419,11 +419,10 @@ public static void CopyTo(this Span source, Span destination, out int w /// The source span. /// A function to test each element for a condition. /// The first element in the span that matches to the specified filter; or . - /// is . - public static Optional FirstOrNone(this ReadOnlySpan span, Predicate filter) + public static Optional FirstOrNone(this ReadOnlySpan span, Predicate? filter = null) { - ArgumentNullException.ThrowIfNull(filter); - + filter ??= Predicate.Constant(true); + for (var i = 0; i < span.Length; i++) { var item = span[i]; @@ -434,6 +433,13 @@ public static Optional FirstOrNone(this ReadOnlySpan span, Predicate return Optional.None; } + internal static Optional LastOrNone(ReadOnlySpan span) + { + ref var elementRef = ref MemoryMarshal.GetReference(span); + var length = span.Length; + return length > 0 ? Unsafe.Add(ref elementRef, length - 1) : Optional.None; + } + internal static bool ElementAt(ReadOnlySpan span, int index, [MaybeNullWhen(false)] out T element) { if ((uint)index < (uint)span.Length)