Skip to content

Commit

Permalink
[stdlib] Support performant negative indexing with UInt.MAX
Browse files Browse the repository at this point in the history
Signed-off-by: Yinon Burgansky <[email protected]>
  • Loading branch information
yinonburgansky committed Jan 23, 2025
1 parent 08417d9 commit 2581aa0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
57 changes: 29 additions & 28 deletions stdlib/src/collections/_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -45,46 +45,47 @@ fn normalize_index[
or _type_is_eq[IdxType, UInt64]()
):
var i = UInt(index(idx))
# TODO: Consider a way to construct the error message after the assert has failed
# something like "Indexing into an empty container" if length == 0 else "..."
debug_assert[assert_mode="safe", cpu_only=True](
i < length,
container_name,
" index out of bounds: ",
" index out of bounds: index (",
i,
" should be less than ",
") valid range: -", # can't print -UInt.MAX
length,
" <= index < ",
length,
)
return i
else:
# Optimize for the common case:
# Proper comparison between Int and UInt is slower and containers with
# more than Int.MAX elements are rare.
# Don't use "safe" since this is considered an overflow error.
debug_assert(
length <= UInt(Int.MAX),
"Overflow Error: ",
container_name,
" length is grater than Int.MAX (",
length,
"). Consider indexing with the UInt type.",
)
var i = Int(idx)
# TODO: Consider a way to construct the error message after the assert has failed
# something like "Indexing into an empty container" if length == 0 else "..."
var i = UInt(index(idx))
if Int(i) < 0:
i += length
# Checking the bounds after the normalization saves a comparison
# while allowing negative indexing into containers with length > Int.MAX.
# For a positive index this is trivially correct.
# For a negative index we can infer the full bounds check from
# the assert UInt(idx + length) < length, by considering 2 cases:
# when length > Int.MAX then:
# idx + length > idx + Int.MAX >= Int.MIN + Int.MAX = -1
# therefore idx + length >= 0
# when length <= Int.MAX then:
# UInt(idx + length) < length <= Int.MAX
# Which means UInt(idx + length) signed bit is off
# therefore idx + length >= 0
# in either case we can infer 0 <= idx + length < length
debug_assert[assert_mode="safe", cpu_only=True](
-Int(length) <= i < Int(length),
i < length,
container_name,
" has length: ",
" index out of bounds: index (",
Int(idx),
") valid range: -", # can't print -UInt.MAX
length,
" <= index < ",
length,
" index out of bounds: ",
i,
" should be between ",
-Int(length),
" and ",
length - 1,
)
if i >= 0:
return i
return i + length
return i


@always_inline
Expand Down
21 changes: 18 additions & 3 deletions stdlib/test/collections/test_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ def test_out_of_bounds_message():
# CHECK: index out of bounds
_ = normalize_index[""](UInt(2), UInt(0))

# CHECK: index out of bounds
_ = normalize_index[""](Int.MIN, 10)
# CHECK: index out of bounds
_ = normalize_index[""](Int.MIN, UInt(10))
# CHECK: index out of bounds
_ = normalize_index[""](Int.MAX, 10)
# CHECK: index out of bounds
_ = normalize_index[""](Int.MAX, UInt(10))
# CHECK: index out of bounds
_ = normalize_index[""](Int.MIN, Int.MAX)

# CHECK: index out of bounds
_ = normalize_index[""](UInt.MAX, 10)
# CHECK: index out of bounds
Expand All @@ -54,9 +65,6 @@ def test_out_of_bounds_message():
# CHECK: index out of bounds
_ = normalize_index[""](UInt.MAX, UInt.MAX - 10)

# CHECK: Overflow Error
_ = normalize_index[""](-1, UInt(Int.MAX + 1))


def test_normalize_index():
assert_equal(normalize_index[""](-3, 3), 0)
Expand Down Expand Up @@ -91,6 +99,13 @@ def test_normalize_index():
assert_equal(normalize_index[""](UInt(1), UInt.MAX), 1)
assert_equal(normalize_index[""](UInt.MAX - 5, UInt.MAX), UInt.MAX - 5)

assert_equal(normalize_index[""](-1, Int.MAX), Int.MAX - 1)
assert_equal(normalize_index[""](-10, Int.MAX), Int.MAX - 10)
assert_equal(normalize_index[""](-1, UInt.MAX), UInt.MAX - 1)
assert_equal(normalize_index[""](-10, UInt.MAX), UInt.MAX - 10)
assert_equal(normalize_index[""](-1, UInt(Int.MAX) + 1), UInt(Int.MAX))
assert_equal(normalize_index[""](Int.MIN, UInt(Int.MAX) + 1), 0)


def main():
test_out_of_bounds_message()
Expand Down

0 comments on commit 2581aa0

Please sign in to comment.