Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

accelerate singleton levels #675

Merged
merged 9 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Finch
using BenchmarkTools
using MatrixDepot
using SparseArrays
using Random
include(joinpath(@__DIR__, "../docs/examples/bfs.jl"))
include(joinpath(@__DIR__, "../docs/examples/pagerank.jl"))
include(joinpath(@__DIR__, "../docs/examples/shortest_paths.jl"))
Expand Down Expand Up @@ -245,6 +246,17 @@ function spmv_serial(A, x)
end
end

"""
    spmv_noinit(y, A, x)

Accumulate `A[j, i] * x[j]` into `y[i]` for every `i` and `j` using a Finch kernel.

`y` is written in place and is never reset inside the kernel, so each product is
added on top of whatever `y` already holds on entry. The function's result is the
value of the `@finch` block, whose final statement is `return y`.
"""
function spmv_noinit(y, A, x)
    @finch begin
        # `i` is the outer loop and `j` the inner one, exactly as in the nested form.
        for i = _, j = _
            y[i] += A[j, i] * x[j]
        end
        return y
    end
end

function spmv_threaded(A, x)
y = Tensor(Dense{Int64}(Element{0.0, Float64}()))
@finch begin
Expand Down Expand Up @@ -302,25 +314,28 @@ end

SUITE["structure"] = BenchmarkGroup()

N = 100_000
N = 1_000_000

SUITE["structure"]["permutation"] = BenchmarkGroup()

A_ref = Tensor(Dense(SparseList(Element(0.0))), fsparse(collect(1:N), collect(1:N), ones(N)))
perm = randperm(N)

A_ref = Tensor(Dense(SparseList(Element(0.0))), fsparse(collect(1:N), perm, ones(N)))

A = Tensor(Dense(SparsePoint(Element(0.0))), A_ref)

x = rand(N)

SUITE["structure"]["permutation"]["SparseList"] = @benchmarkable spmv_serial($A_ref, $x)
SUITE["structure"]["permutation"]["SparsePoint"] = @benchmarkable spmv_serial($A, $x)
SUITE["structure"]["permutation"]["baseline"] = @benchmarkable $x[$perm]

SUITE["structure"]["banded"] = BenchmarkGroup()

A_ref = Tensor(Dense(Sparse(Element(0.0))), N, N)

@finch for i = _, j = _
if abs(i - j) < 2
@finch for j = _, i = _
if j - 2 < i < j + 2
A_ref[i, j] = 1.0
end
end
Expand Down
124 changes: 44 additions & 80 deletions src/tensors/levels/sparse_band_levels.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
SparseBandLevel{[Ti=Int], [Ptr, Idx, Ofs]}(lvl, [dim])
SparseBandLevel{[Ti=Int], [Idx, Ofs]}(lvl, [dim])

Like the [`SparseBlockListLevel`](@ref), but stores only a single block, and fills in zeros.

Expand All @@ -14,10 +14,9 @@
│ ├─[1]: 20.0
│ ├─[3]: 40.0
"""
struct SparseBandLevel{Ti, Ptr<:AbstractVector, Idx<:AbstractVector, Ofs<:AbstractVector, Lvl} <: AbstractLevel
struct SparseBandLevel{Ti, Idx<:AbstractVector, Ofs<:AbstractVector, Lvl} <: AbstractLevel
lvl::Lvl
shape::Ti
ptr::Ptr
idx::Idx
ofs::Ofs
end
Expand All @@ -26,40 +25,39 @@
SparseBandLevel(lvl::Lvl) where {Lvl} = SparseBandLevel{Int}(lvl)
SparseBandLevel(lvl, shape, args...) = SparseBandLevel{typeof(shape)}(lvl, shape, args...)
SparseBandLevel{Ti}(lvl) where {Ti} = SparseBandLevel{Ti}(lvl, zero(Ti))
SparseBandLevel{Ti}(lvl, shape) where {Ti} = SparseBandLevel{Ti}(lvl, shape, postype(lvl)[1], Ti[], postype(lvl)[])
SparseBandLevel{Ti}(lvl::Lvl, shape, ptr::Ptr, idx::Idx, ofs::Ofs) where {Ti, Lvl, Ptr, Idx, Ofs} =
SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}(lvl, Ti(shape), ptr, idx, ofs)
SparseBandLevel{Ti}(lvl, shape) where {Ti} = SparseBandLevel{Ti}(lvl, shape, Ti[], postype(lvl)[])
SparseBandLevel{Ti}(lvl::Lvl, shape, idx::Idx, ofs::Ofs) where {Ti, Lvl, Idx, Ofs} =
SparseBandLevel{Ti, Idx, Ofs, Lvl}(lvl, Ti(shape), idx, ofs)

function postype(::Type{SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl}
function postype(::Type{SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl}
return postype(Lvl)
end

function moveto(lvl::SparseBandLevel{Ti}, device) where {Ti}
lvl_2 = moveto(lvl.lvl, device)
ptr_2 = moveto(lvl.ptr, device)
idx_2 = moveto(lvl.idx, device)
ofs_2 = moveto(lvl.ofs, device)
return SparseBandLevel{Ti}(lvl_2, lvl.shape, ptr_2, idx_2, ofs_2)
return SparseBandLevel{Ti}(lvl_2, lvl.shape, idx_2, ofs_2)

Check warning on line 40 in src/tensors/levels/sparse_band_levels.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/levels/sparse_band_levels.jl#L40

Added line #L40 was not covered by tests
end

Base.summary(lvl::SparseBandLevel) = "SparseBand($(summary(lvl.lvl)))"
similar_level(lvl::SparseBandLevel, fill_value, eltype::Type, dim, tail...) =
SparseBand(similar_level(lvl.lvl, fill_value, eltype, tail...), dim)

pattern!(lvl::SparseBandLevel{Ti}) where {Ti} =
SparseBandLevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.ptr, lvl.idx, lvl.ofs)
SparseBandLevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.idx, lvl.ofs)

function countstored_level(lvl::SparseBandLevel, pos)
countstored_level(lvl.lvl, lvl.ofs[lvl.ptr[pos + 1]]-1)
countstored_level(lvl.lvl, lvl.ofs[pos + 1]-1)
end

set_fill_value!(lvl::SparseBandLevel{Ti}, init) where {Ti} =
SparseBandLevel{Ti}(set_fill_value!(lvl.lvl, init), lvl.shape, lvl.ptr, lvl.idx, lvl.ofs)
SparseBandLevel{Ti}(set_fill_value!(lvl.lvl, init), lvl.shape, lvl.idx, lvl.ofs)

Base.resize!(lvl::SparseBandLevel{Ti}, dims...) where {Ti} =
SparseBandLevel{Ti}(resize!(lvl.lvl, dims[1:end-1]...), dims[end], lvl.ptr, lvl.idx, lvl.ofs)
SparseBandLevel{Ti}(resize!(lvl.lvl, dims[1:end-1]...), dims[end], lvl.idx, lvl.ofs)

function Base.show(io::IO, lvl::SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}) where {Ti, Ptr, Idx, Ofs, Lvl}
function Base.show(io::IO, lvl::SparseBandLevel{Ti, Idx, Ofs, Lvl}) where {Ti, Idx, Ofs, Lvl}
if get(io, :compact, false)
print(io, "SparseBand(")
else
Expand All @@ -72,8 +70,6 @@
if get(io, :compact, false)
print(io, "…")
else
show(io, lvl.ptr)
print(io, ", ")
show(io, lvl.idx)
print(io, ", ")
show(io, lvl.ofs)
Expand All @@ -87,37 +83,34 @@
function labelled_children(fbr::SubFiber{<:SparseBandLevel})
lvl = fbr.lvl
pos = fbr.pos
pos + 1 > length(lvl.ptr) && return []
res = []
for r = lvl.ptr[pos]:lvl.ptr[pos + 1] - 1
i = lvl.idx[r]
qos = lvl.ofs[r]
l = lvl.ofs[r + 1] - lvl.ofs[r]
for qos = lvl.ofs[r]:lvl.ofs[r + 1] - 1
push!(res, LabelledTree(cartesian_label([range_label() for _ = 1:ndims(fbr) - 1]..., i - (lvl.ofs[r + 1] - 1) + qos), SubFiber(lvl.lvl, qos)))
end
for qos = lvl.ofs[pos]:lvl.ofs[pos + 1] - 1
i = lvl.idx[pos] - lvl.ofs[pos + 1] + qos + 1
push!(res, LabelledTree(cartesian_label([range_label() for _ = 1:ndims(fbr) - 1]..., i), SubFiber(lvl.lvl, qos)))

Check warning on line 89 in src/tensors/levels/sparse_band_levels.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/levels/sparse_band_levels.jl#L87-L89

Added lines #L87 - L89 were not covered by tests
end
res
end

@inline level_ndims(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = 1 + level_ndims(Lvl)
@inline level_ndims(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = 1 + level_ndims(Lvl)
@inline level_size(lvl::SparseBandLevel) = (level_size(lvl.lvl)..., lvl.shape)
@inline level_axes(lvl::SparseBandLevel) = (level_axes(lvl.lvl)..., Base.OneTo(lvl.shape))
@inline level_eltype(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = level_eltype(Lvl)
@inline level_fill_value(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = SparseData(data_rep_level(Lvl))
@inline level_eltype(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = level_eltype(Lvl)
@inline level_fill_value(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = SparseData(data_rep_level(Lvl))

Check warning on line 99 in src/tensors/levels/sparse_band_levels.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/levels/sparse_band_levels.jl#L99

Added line #L99 was not covered by tests

(fbr::AbstractFiber{<:SparseBandLevel})() = fbr
function (fbr::SubFiber{<:SparseBandLevel})(idxs...)
isempty(idxs) && return fbr
lvl = fbr.lvl
p = fbr.pos
r = lvl.ptr[p] + searchsortedfirst(@view(lvl.idx[lvl.ptr[p]:lvl.ptr[p + 1] - 1]), idxs[end]) - 1
r < lvl.ptr[p + 1] || return fill_value(fbr)
q = lvl.ofs[r + 1] - 1 - lvl.idx[r] + idxs[end]
q >= lvl.ofs[r] || return fill_value(fbr)
fbr_2 = SubFiber(lvl.lvl, q)
return fbr_2(idxs[1:end-1]...)
pos = fbr.pos
start = lvl.idx[pos] - lvl.ofs[pos + 1] + lvl.ofs[pos] + 1
stop = lvl.idx[pos]
if start <= idxs[end] <= stop
qos = lvl.ofs[pos] + idxs[end] - start
fbr_2 = SubFiber(lvl.lvl, qos)
return fbr_2(idxs[1:end-1]...)
end
return fill_value(fbr)
end

mutable struct VirtualSparseBandLevel <: AbstractVirtualLevel
Expand All @@ -130,7 +123,6 @@
ros_fill
ros_stop
dirty
ptr
idx
ofs
prev_pos
Expand All @@ -149,33 +141,30 @@
postype(lvl::VirtualSparseBandLevel) = postype(lvl.lvl)


function virtualize(ctx, ex, ::Type{SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}, tag=:lvl) where {Ti, Ptr, Idx, Ofs, Lvl}
function virtualize(ctx, ex, ::Type{SparseBandLevel{Ti, Idx, Ofs, Lvl}}, tag=:lvl) where {Ti, Idx, Ofs, Lvl}
sym = freshen(ctx, tag)
shape = value(:($sym.shape), Int)
qos_fill = freshen(ctx, sym, :_qos_fill)
qos_stop = freshen(ctx, sym, :_qos_stop)
ros_fill = freshen(ctx, sym, :_ros_fill)
ros_stop = freshen(ctx, sym, :_ros_stop)
dirty = freshen(ctx, sym, :_dirty)
ptr = freshen(ctx, tag, :_ptr)
idx = freshen(ctx, tag, :_idx)
ofs = freshen(ctx, tag, :_ofs)
push_preamble!(ctx, quote
$sym = $ex
$ptr = $sym.ptr
$idx = $sym.idx
$ofs = $sym.ofs
end)
prev_pos = freshen(ctx, sym, :_prev_pos)
lvl_2 = virtualize(ctx, :($sym.lvl), Lvl, sym)
VirtualSparseBandLevel(lvl_2, sym, Ti, shape, qos_fill, qos_stop, ros_fill, ros_stop, dirty, ptr, idx, ofs, prev_pos)
VirtualSparseBandLevel(lvl_2, sym, Ti, shape, qos_fill, qos_stop, ros_fill, ros_stop, dirty, idx, ofs, prev_pos)
end
function lower(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, ::DefaultStyle)
quote
$SparseBandLevel{$(lvl.Ti)}(
$(ctx(lvl.lvl)),
$(ctx(lvl.shape)),
$(lvl.ptr),
$(lvl.idx),
$(lvl.ofs),
)
Expand All @@ -199,19 +188,15 @@
virtual_level_fill_value(lvl::VirtualSparseBandLevel) = virtual_level_fill_value(lvl.lvl)

function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, arch)
ptr_2 = freshen(ctx, lvl.ptr)
tbl_2 = freshen(ctx, lvl.tbl)
ofs_2 = freshen(ctx, lvl.ofs)
push_preamble!(ctx, quote
$ptr_2 = $(lvl.ptr)
$tbl_2 = $(lvl.tbl)
$ofs_2 = $(lvl.ofs)
$(lvl.ptr) = $moveto($(lvl.ptr), $(ctx(arch)))
$(lvl.tbl) = $moveto($(lvl.tbl), $(ctx(arch)))
$(lvl.ofs) = $moveto($(lvl.ofs), $(ctx(arch)))
end)
push_epilogue!(ctx, quote
$(lvl.ptr) = $ptr_2
$(lvl.tbl) = $tbl_2
$(lvl.ofs) = $ofs_2
end)
Expand All @@ -224,8 +209,6 @@
push_preamble!(ctx, quote
$(lvl.qos_fill) = $(Tp(0))
$(lvl.qos_stop) = $(Tp(0))
$(lvl.ros_fill) = $(Tp(0))
$(lvl.ros_stop) = $(Tp(0))
Finch.resize_if_smaller!($(lvl.ofs), 1)
$(lvl.ofs)[1] = 1
end)
Expand All @@ -242,26 +225,25 @@
pos_start = ctx(cache!(ctx, :p_start, pos_start))
pos_stop = ctx(cache!(ctx, :p_start, pos_stop))
return quote
Finch.resize_if_smaller!($(lvl.ptr), $pos_stop + 1)
Finch.fill_range!($(lvl.ptr), 0, $pos_start + 1, $pos_stop + 1)
Finch.resize_if_smaller!($(lvl.idx), $pos_stop)
Finch.fill_range!($(lvl.idx), 1, $pos_start, $pos_stop)
Finch.resize_if_smaller!($(lvl.ofs), $pos_stop + 1)
Finch.fill_range!($(lvl.ofs), 0, $pos_start + 1, $pos_stop + 1)
end
end

function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos_stop)
p = freshen(ctx, :p)
Tp = postype(lvl)
pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop)))
ros_stop = freshen(ctx, :ros_stop)
qos_stop = freshen(ctx, :qos_stop)
push_preamble!(ctx, quote
resize!($(lvl.ptr), $pos_stop + 1)
resize!($(lvl.idx), $pos_stop)
resize!($(lvl.ofs), $pos_stop + 1)
for $p = 2:($pos_stop + 1)
$(lvl.ptr)[$p] += $(lvl.ptr)[$p - 1]
$(lvl.ofs)[$p] += $(lvl.ofs)[$p - 1]
end
$ros_stop = $(lvl.ptr)[$pos_stop + 1] - 1
resize!($(lvl.idx), $ros_stop)
resize!($(lvl.ofs), $ros_stop + 1)
$qos_stop = $(lvl.ofs)[$ros_stop + 1] - $(Tp(1))
$qos_stop = $(lvl.ofs)[$pos_stop + 1] - $(Tp(1))
end)
lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop))
return lvl
Expand All @@ -285,19 +267,10 @@
arr = fbr,
body = Thunk(
preamble = quote
$my_r = $(lvl.ptr)[$(ctx(pos))]
$my_r_stop = $(lvl.ptr)[$(ctx(pos)) + $(Tp(1))] - 1
if $my_r <= $my_r_stop
$my_i1 = $(lvl.idx)[$my_r]
$my_q_stop = $(lvl.ofs)[$my_r + $(Tp(1))]
$my_i_start = $my_i1 - ($my_q_stop - $(lvl.ofs)[$my_r] - 1)
$my_q_ofs = $my_q_stop - $my_i1 - $(Tp(1))
else
$my_i_start = $(Ti(1))
$my_i1 = $(Ti(0))
$my_q_stop = $(Ti(0))
$my_q = $(Ti(0))
end
$my_i1 = $(lvl.idx)[$(ctx(pos))]
$my_q_stop = $(lvl.ofs)[$(ctx(pos)) + $(Tp(1))]
$my_i_start = $my_i1 - ($my_q_stop - $(lvl.ofs)[$(ctx(pos))] - 1)
$my_q_ofs = $my_q_stop - $my_i1 - $(Tp(1))
end,
body = (ctx) -> Sequence([
Phase(
Expand Down Expand Up @@ -346,7 +319,6 @@
arr = fbr,
body = Thunk(
preamble = quote
$ros = $ros_fill
$qos = $qos_fill + 1
$qos_set = $qos_fill
$my_i_prev = $(Ti(-1))
Expand Down Expand Up @@ -394,24 +366,16 @@
),
epilogue = quote
if $my_i_prev > 0
$ros += 1
if $ros > $ros_stop
$ros_stop = max($ros_stop << 1, 1)
Finch.resize_if_smaller!($(lvl.idx), $ros_stop)
Finch.resize_if_smaller!($(lvl.ofs), $ros_stop + 1)
end
$qos = $qos_set
$(lvl.idx)[$(ros)] = $my_i_set
$(lvl.ofs)[$(ros) + 1] = $qos + 1
$(lvl.idx)[$(ctx(pos))] = $my_i_set
$(lvl.ofs)[$(ctx(pos)) + 1] = $my_i_set - $my_i_prev + 1
$(if issafe(get_mode_flag(ctx))
quote
$(lvl.prev_pos) = $(ctx(pos))
end
end)
$qos_fill = $qos
end
$(lvl.ptr)[$(ctx(pos)) + 1] += $ros - $ros_fill
$ros_fill = $ros
end
)
)
Expand Down
1 change: 1 addition & 0 deletions src/tensors/levels/sparse_dict_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@
function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, arch)
ptr_2 = freshen(ctx, lvl.ptr)
idx_2 = freshen(ctx, lvl.idx)
tbl_2 = freshen(ctx, lvl.tbl_2)

Check warning on line 307 in src/tensors/levels/sparse_dict_levels.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/levels/sparse_dict_levels.jl#L307

Added line #L307 was not covered by tests
push_preamble!(ctx, quote
$tbl_2 = $(lvl.tbl)
$(lvl.tbl) = $moveto($(lvl.tbl), $(ctx(arch)))
Expand Down
Loading
Loading