From 2581aa0106dc271eb7b40ea6e91f4ef876ed8787 Mon Sep 17 00:00:00 2001 From: Yinon Burgansky Date: Thu, 23 Jan 2025 09:18:00 +0200 Subject: [PATCH] [stdlib] Support performant negative indexing with UInt.MAX Signed-off-by: Yinon Burgansky --- .../src/collections/_index_normalization.mojo | 57 ++++++++++--------- .../collections/test_index_normalization.mojo | 21 ++++++- 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/stdlib/src/collections/_index_normalization.mojo b/stdlib/src/collections/_index_normalization.mojo index 4191c7bf6c..87c59b8ecb 100644 --- a/stdlib/src/collections/_index_normalization.mojo +++ b/stdlib/src/collections/_index_normalization.mojo @@ -45,46 +45,47 @@ fn normalize_index[ or _type_is_eq[IdxType, UInt64]() ): var i = UInt(index(idx)) + # TODO: Consider a way to construct the error message after the assert has failed + # something like "Indexing into an empty container" if length == 0 else "..." debug_assert[assert_mode="safe", cpu_only=True]( i < length, container_name, - " index out of bounds: ", + " index out of bounds: index (", i, - " should be less than ", + ") valid range: -", # can't print -UInt.MAX + length, + " <= index < ", length, ) return i else: - # Optimize for the common case: - # Proper comparison between Int and UInt is slower and containers with - # more than Int.MAX elements are rare. - # Don't use "safe" since this is considered an overflow error. - debug_assert( - length <= UInt(Int.MAX), - "Overflow Error: ", - container_name, - " length is grater than Int.MAX (", - length, - "). Consider indexing with the UInt type.", - ) - var i = Int(idx) - # TODO: Consider a way to construct the error message after the assert has failed - # something like "Indexing into an empty container" if length == 0 else "..." + var i = UInt(index(idx)) + if Int(i) < 0: + i += length + # Checking the bounds after the normalization saves a comparison + # while allowing negative indexing into containers with length > Int.MAX. 
+ # For a positive index this is trivially correct. + # For a negative index we can infer the full bounds check from + # the assert UInt(idx + length) < length, by considering 2 cases: + # when length > Int.MAX then: + # idx + length > idx + Int.MAX >= Int.MIN + Int.MAX = -1 + # therefore idx + length >= 0 + # when length <= Int.MAX then: + # UInt(idx + length) < length <= Int.MAX + # Which means UInt(idx + length) sign bit is off + # therefore idx + length >= 0 + # in either case we can infer 0 <= idx + length < length debug_assert[assert_mode="safe", cpu_only=True]( - -Int(length) <= i < Int(length), + i < length, container_name, - " has length: ", + " index out of bounds: index (", + Int(idx), + ") valid range: -", # can't print -UInt.MAX + length, + " <= index < ", length, - " index out of bounds: ", - i, - " should be between ", - -Int(length), - " and ", - length - 1, ) - if i >= 0: - return i - return i + length + return i @always_inline diff --git a/stdlib/test/collections/test_index_normalization.mojo b/stdlib/test/collections/test_index_normalization.mojo index 83a60bcec2..437c702cd7 100644 --- a/stdlib/test/collections/test_index_normalization.mojo +++ b/stdlib/test/collections/test_index_normalization.mojo @@ -45,6 +45,17 @@ def test_out_of_bounds_message(): # CHECK: index out of bounds _ = normalize_index[""](UInt(2), UInt(0)) + # CHECK: index out of bounds + _ = normalize_index[""](Int.MIN, 10) + # CHECK: index out of bounds + _ = normalize_index[""](Int.MIN, UInt(10)) + # CHECK: index out of bounds + _ = normalize_index[""](Int.MAX, 10) + # CHECK: index out of bounds + _ = normalize_index[""](Int.MAX, UInt(10)) + # CHECK: index out of bounds + _ = normalize_index[""](Int.MIN, Int.MAX) + # CHECK: index out of bounds _ = normalize_index[""](UInt.MAX, 10) # CHECK: index out of bounds @@ -54,9 +65,6 @@ def test_out_of_bounds_message(): # CHECK: index out of bounds _ = normalize_index[""](UInt.MAX, UInt.MAX - 10) - # CHECK: Overflow Error - _ = 
normalize_index[""](-1, UInt(Int.MAX + 1)) - def test_normalize_index(): assert_equal(normalize_index[""](-3, 3), 0) @@ -91,6 +99,13 @@ def test_normalize_index(): assert_equal(normalize_index[""](UInt(1), UInt.MAX), 1) assert_equal(normalize_index[""](UInt.MAX - 5, UInt.MAX), UInt.MAX - 5) + assert_equal(normalize_index[""](-1, Int.MAX), Int.MAX - 1) + assert_equal(normalize_index[""](-10, Int.MAX), Int.MAX - 10) + assert_equal(normalize_index[""](-1, UInt.MAX), UInt.MAX - 1) + assert_equal(normalize_index[""](-10, UInt.MAX), UInt.MAX - 10) + assert_equal(normalize_index[""](-1, UInt(Int.MAX) + 1), UInt(Int.MAX)) + assert_equal(normalize_index[""](Int.MIN, UInt(Int.MAX) + 1), 0) + def main(): test_out_of_bounds_message()