Skip to content

Commit

Permalink
Merge pull request #6 from SciML/sorted
Browse files Browse the repository at this point in the history
Sorted
  • Loading branch information
ChrisRackauckas authored Jan 25, 2024
2 parents 9d539e6 + 591b366 commit fc6353b
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 94 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,25 @@ searchsortedfirstcorrelated(v::AbstractVector{T}, x, guess::T)

An accelerated `findfirst` on sorted vectors using a bracketed search. Requires a `guess`
to start the search from.


Some benchmarks:
```julia
julia> x = rand(Int, 2048); s = sort(x);

julia> @btime findfirst(==($x[1011]), $x)
266.427 ns (0 allocations: 0 bytes)
1011

julia> @btime FindFirstFunctions.findfirstequal($x[1011], $x)
67.502 ns (0 allocations: 0 bytes)
1011

julia> @btime searchsortedfirst($s, $s[1011])
8.897 ns (0 allocations: 0 bytes)
1011

julia> @btime FindFirstFunctions.findfirstsortedequal($s[1011], $s)
10.896 ns (0 allocations: 0 bytes)
1011
```
212 changes: 119 additions & 93 deletions src/FindFirstFunctions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,69 @@
module FindFirstFunctions

function _findfirstequal(vpivot::Int64, ptr::Ptr{Int64}, len::Int64)
Base.llvmcall(("""
declare i8 @llvm.cttz.i8(i8, i1);
define i64 @entry(i64 %0, i64 %1, i64 %2) #0 {
top:
%ivars = inttoptr i64 %1 to i64*
%btmp = insertelement <8 x i64> undef, i64 %0, i64 0
%var = shufflevector <8 x i64> %btmp, <8 x i64> undef, <8 x i32> zeroinitializer
%lenm7 = add nsw i64 %2, -7
%dosimditer = icmp ugt i64 %2, 7
br i1 %dosimditer, label %L9.lr.ph, label %L32
L9.lr.ph:
%len8 = and i64 %2, 9223372036854775800
br label %L9
L9:
%i = phi i64 [ 0, %L9.lr.ph ], [ %vinc, %L30 ]
%ivarsi = getelementptr inbounds i64, i64* %ivars, i64 %i
%vpvi = bitcast i64* %ivarsi to <8 x i64>*
%v = load <8 x i64>, <8 x i64>* %vpvi, align 8
%m = icmp eq <8 x i64> %v, %var
%mu = bitcast <8 x i1> %m to i8
%matchnotfound = icmp eq i8 %mu, 0
br i1 %matchnotfound, label %L30, label %L17
L17:
%tz8 = call i8 @llvm.cttz.i8(i8 %mu, i1 true)
%tz64 = zext i8 %tz8 to i64
%vis = add nuw i64 %i, %tz64
br label %common.ret
common.ret:
%retval = phi i64 [ %vis, %L17 ], [ -1, %L32 ], [ %si, %L51 ], [ -1, %L67 ]
ret i64 %retval
L30:
%vinc = add nuw nsw i64 %i, 8
%continue = icmp slt i64 %vinc, %lenm7
br i1 %continue, label %L9, label %L32
L32:
%cumi = phi i64 [ 0, %top ], [ %len8, %L30 ]
%done = icmp eq i64 %cumi, %2
br i1 %done, label %common.ret, label %L51
L51:
%si = phi i64 [ %inc, %L67 ], [ %cumi, %L32 ]
%spi = getelementptr inbounds i64, i64* %ivars, i64 %si
%svi = load i64, i64* %spi, align 8
%match = icmp eq i64 %svi, %0
br i1 %match, label %common.ret, label %L67
L67:
%inc = add i64 %si, 1
%dobreak = icmp eq i64 %inc, %2
br i1 %dobreak, label %common.ret, label %L51
}
attributes #0 = { alwaysinline }
""", "entry"), Int64, Tuple{Int64,Ptr{Int64},Int64}, vpivot, ptr,
len)
end

"""
findfirstequal(x::Int64,A::DenseVector{Int64})
Expand All @@ -8,71 +72,33 @@ Finds the first value in `A` equal to `x`
findfirstequal(vpivot, ivars) = findfirst(isequal(vpivot), ivars)
function findfirstequal(vpivot::Int64, ivars::DenseVector{Int64})
GC.@preserve ivars begin
ret = Base.llvmcall(("""
declare i8 @llvm.cttz.i8(i8, i1);
define i64 @entry(i64 %0, i64 %1, i64 %2) #0 {
top:
%ivars = inttoptr i64 %1 to i64*
%btmp = insertelement <8 x i64> undef, i64 %0, i64 0
%var = shufflevector <8 x i64> %btmp, <8 x i64> undef, <8 x i32> zeroinitializer
%lenm7 = add nsw i64 %2, -7
%dosimditer = icmp ugt i64 %2, 7
br i1 %dosimditer, label %L9.lr.ph, label %L32
L9.lr.ph:
%len8 = and i64 %2, 9223372036854775800
br label %L9
L9:
%i = phi i64 [ 0, %L9.lr.ph ], [ %vinc, %L30 ]
%ivarsi = getelementptr inbounds i64, i64* %ivars, i64 %i
%vpvi = bitcast i64* %ivarsi to <8 x i64>*
%v = load <8 x i64>, <8 x i64>* %vpvi, align 8
%m = icmp eq <8 x i64> %v, %var
%mu = bitcast <8 x i1> %m to i8
%matchnotfound = icmp eq i8 %mu, 0
br i1 %matchnotfound, label %L30, label %L17
L17:
%tz8 = call i8 @llvm.cttz.i8(i8 %mu, i1 true)
%tz64 = zext i8 %tz8 to i64
%vis = add nuw i64 %i, %tz64
br label %common.ret
common.ret:
%retval = phi i64 [ %vis, %L17 ], [ -1, %L32 ], [ %si, %L51 ], [ -1, %L67 ]
ret i64 %retval
L30:
%vinc = add nuw nsw i64 %i, 8
%continue = icmp slt i64 %vinc, %lenm7
br i1 %continue, label %L9, label %L32
L32:
%cumi = phi i64 [ 0, %top ], [ %len8, %L30 ]
%done = icmp eq i64 %cumi, %2
br i1 %done, label %common.ret, label %L51
L51:
%si = phi i64 [ %inc, %L67 ], [ %cumi, %L32 ]
%spi = getelementptr inbounds i64, i64* %ivars, i64 %si
%svi = load i64, i64* %spi, align 8
%match = icmp eq i64 %svi, %0
br i1 %match, label %common.ret, label %L67
L67:
%inc = add i64 %si, 1
%dobreak = icmp eq i64 %inc, %2
br i1 %dobreak, label %common.ret, label %L51
}
attributes #0 = { alwaysinline }
""", "entry"), Int64, Tuple{Int64,Ptr{Int64},Int64}, vpivot, pointer(ivars),
length(ivars))
ret = _findfirstequal(vpivot, pointer(ivars), length(ivars))
end
ret < 0 ? nothing : ret + 1
end

"""
findfirstsortedequal(vars::DenseVector{Int64}, var::Int64)::Union{Int64,Nothing}
Note that this differs from `searchsortedfirst` by returning `nothing` when absent.
"""
function findfirstsortedequal(var::Int64, vars::DenseVector{Int64},
::Val{basecase}=Val(16)) where {basecase}
len = length(vars)
offset = 0
@inbounds while len > basecase
half = len >>> 1 # half on left, len - half on right
offset = ifelse(vars[offset+half+1] <= var, half + offset, offset)
len = len - half
end
# maybe occurs in vars[offset+1:offset+len]
GC.@preserve vars begin
ret = _findfirstequal(var, pointer(vars) + 8offset, len)
end
ret < 0 ? nothing : ret + offset + 1
end


"""
bracketstrictlymontonic(v, x, guess; lt=<comparison>, by=<transform>, rev=false)
Expand All @@ -94,36 +120,36 @@ this function would be the index returned by the previous call to `searchsorted`
See `Base.sort!` for an explanation of the keyword arguments `by`, `lt` and `rev`.
"""
function bracketstrictlymontonic(v::AbstractVector,
x,
guess::T,
o::Base.Order.Ordering)::NTuple{2, keytype(v)} where {T <: Integer}
bottom = firstindex(v)
top = lastindex(v)
if guess < bottom || guess > top
return bottom, top
# # NOTE: for cache efficiency in repeated calls, we avoid accessing the first and last elements of `v`
# # on each call to this function. This should only result in significant slow downs for calls with
# # out-of-bounds values of `x` *and* bad `guess`es.
# elseif lt(o, x, v[bottom])
# return bottom, bottom
# elseif lt(o, v[top], x)
# return top, top
x,
guess::T,
o::Base.Order.Ordering)::NTuple{2,keytype(v)} where {T<:Integer}
bottom = firstindex(v)
top = lastindex(v)
if guess < bottom || guess > top
return bottom, top
# # NOTE: for cache efficiency in repeated calls, we avoid accessing the first and last elements of `v`
# # on each call to this function. This should only result in significant slow downs for calls with
# # out-of-bounds values of `x` *and* bad `guess`es.
# elseif lt(o, x, v[bottom])
# return bottom, bottom
# elseif lt(o, v[top], x)
# return top, top
else
u = T(1)
lo, hi = guess, min(guess + u, top)
@inbounds if Base.Order.lt(o, x, v[lo])
while lo > bottom && Base.Order.lt(o, x, v[lo])
lo, hi = max(bottom, lo - u), lo
u += u
end
else
u = T(1)
lo, hi = guess, min(guess + u, top)
@inbounds if Base.Order.lt(o, x, v[lo])
while lo > bottom && Base.Order.lt(o, x, v[lo])
lo, hi = max(bottom, lo - u), lo
u += u
end
else
while hi < top && !Base.Order.lt(o, x, v[hi])
lo, hi = hi, min(top, hi + u)
u += u
end
end
while hi < top && !Base.Order.lt(o, x, v[hi])
lo, hi = hi, min(top, hi + u)
u += u
end
end
return lo, hi
end
return lo, hi
end

"""
Expand All @@ -133,13 +159,13 @@ An accelerated `findfirst` on sorted vectors using a bracketed search. Requires
to start the search from.
"""
function searchsortedfirstcorrelated(v::AbstractVector, x, guess)
lo, hi = bracketstrictlymontonic(v, x, guess, Base.Order.Forward)
searchsortedfirst(v, x, lo, hi, Base.Order.Forward)
lo, hi = bracketstrictlymontonic(v, x, guess, Base.Order.Forward)
searchsortedfirst(v, x, lo, hi, Base.Order.Forward)
end

function searchsortedlastcorrelated(v::AbstractVector, x, guess)
lo, hi = bracketstrictlymontonic(v, x, guess, Base.Order.Forward)
searchsortedlast(v, x, lo, hi, Base.Order.Forward)
lo, hi = bracketstrictlymontonic(v, x, guess, Base.Order.Forward)
searchsortedlast(v, x, lo, hi, Base.Order.Forward)
end

searchsortedfirstcorrelated(r::AbstractRange, x, _) = searchsortedfirst(r, x)
Expand Down
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@ using Test

for n = 0:128
x = unique!(rand(Int, n))
s = sort(x)
for i = eachindex(x)
@test FindFirstFunctions.findfirstequal(x[i], x) == i
@test FindFirstFunctions.findfirstequal(s[i], s) == i
@test FindFirstFunctions.findfirstsortedequal(s[i], s) == i
end
if length(x) > 0
@test FindFirstFunctions.findfirstequal(x[begin], @view(x[begin:end])) === 1
@test FindFirstFunctions.findfirstequal(x[begin], @view(x[begin+1:end])) === nothing
@test FindFirstFunctions.findfirstequal(x[end], @view(x[begin:end-1])) === nothing
end
y = rand(Int)
@test FindFirstFunctions.findfirstequal(y, x) === findfirst(==(y), x)
ff = findfirst(==(y), x)
@test FindFirstFunctions.findfirstequal(y, x) === ff
ff === nothing && @test FindFirstFunctions.findfirstsortedequal(y, x) === nothing
end

end
Expand Down

0 comments on commit fc6353b

Please sign in to comment.