Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep Galley Plans Per Approximate Sparsity Pattern #679

Merged
merged 28 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5896d1a
add issimilar and get_cannonical_stats
Dec 19, 2024
7b69f89
cleanup
Dec 19, 2024
6761338
small fixes
Dec 19, 2024
4bbf23f
fix stats construction
Dec 19, 2024
092579c
last fix, definitely true this time
Dec 19, 2024
79d836d
fix NaiveStats constructor, again
Dec 20, 2024
2ed581d
small change to verbose passing
Dec 20, 2024
e835a12
fix deferred hash issue #664
Dec 20, 2024
4de8a9c
fix deferred equality check
Dec 20, 2024
85ac56e
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Dec 28, 2024
c53f9fb
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Dec 28, 2024
354db43
Merge remote-tracking branch 'origin/main' into kbd-make-galley-adapt…
willow-ahrens Dec 28, 2024
b5b2831
more accurate benchmark
willow-ahrens Dec 28, 2024
0c8ab41
add evaluation count to benchmark setup
willow-ahrens Dec 28, 2024
bff9d0d
rename GalleyExecutor to AdaptiveExecutor
Jan 2, 2025
5aa6d01
warn when the tag argument is given to the AdaptiveExecutor
Jan 2, 2025
816d974
remove with_scheduler issue
Jan 2, 2025
ffa437f
bug fix
Jan 2, 2025
cf19c69
lowering default threshold & warn on tag argument
Jan 2, 2025
d24a6af
update high-level benchmarks to set scheduler in setup
Jan 2, 2025
e31b572
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
kylebd99 Jan 2, 2025
5fd941b
small fix
Jan 3, 2025
f52839f
add compute overhead check
Jan 3, 2025
595f7b6
temporarily remove galley scheduler
Jan 3, 2025
910b033
add galley_scheduler back
Jan 6, 2025
1dbc64c
drop the warn on tag arg
Jan 6, 2025
bd2a62d
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Jan 6, 2025
fb7d3ca
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 102 additions & 78 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ end
if "FINCH_BENCHMARK_ARGS" in keys(ENV)
ARGS = split(ENV["FINCH_BENCHMARK_ARGS"], " ")
end

parsed_args = parse_args(ARGS, s)

include(joinpath(@__DIR__, "../docs/examples/bfs.jl"))
Expand All @@ -47,90 +46,115 @@ for (scheduler_name, scheduler) in [
"default_scheduler" => Finch.default_scheduler(),
"galley_scheduler" => Finch.galley_scheduler(),
]
Finch.with_scheduler(scheduler) do
let
A = Tensor(Dense(Sparse(Element(0.0))), fsprand(10000, 10000, 0.01))
SUITE["high-level"]["permutedims(Dense(Sparse()))"][scheduler_name] = @benchmarkable(permutedims($A, (2, 1)))
end
# Benchmark permutedims on a sparse 10000x10000 tensor; the scheduler under test is
# installed in `setup` so scheduler selection is excluded from the timed region.
let
A = Tensor(Dense(Sparse(Element(0.0))), fsprand(10000, 10000, 0.01))
SUITE["high-level"]["permutedims(Dense(Sparse()))"][scheduler_name] = @benchmarkable(permutedims($A, (2, 1)), setup = (Finch.set_scheduler!($scheduler)))
end

let
A = Tensor(Dense(Dense(Element(0.0))), rand(10000, 10000))
SUITE["high-level"]["permutedims(Dense(Dense()))"][scheduler_name] = @benchmarkable(permutedims($A, (2, 1)))
end
# Benchmark permutedims on a fully dense 10000x10000 tensor, with the scheduler
# installed in `setup` (mirrors the sparse permutedims benchmark above it in the suite).
let
A = Tensor(Dense(Dense(Element(0.0))), rand(10000, 10000))
SUITE["high-level"]["permutedims(Dense(Dense()))"][scheduler_name] = @benchmarkable(permutedims($A, (2, 1)), setup = (Finch.set_scheduler!($scheduler)))
end

let
k = Ref(0.0)
x = rand(1)
y = rand(1)
SUITE["high-level"]["einsum_spmv_compile_overhead"][scheduler_name] = @benchmarkable(
begin
A, x, y = (A, $x, $y)
@einsum y[i] += A[i, j] * x[j]
end,
setup = (A = Tensor(Dense(SparseList(Element($k[] += 1))), fsprand(1, 1, 1)))
)
end
# Measure @einsum SpMV on a trivially small (1x1) input so the timing is dominated
# by scheduler/compile overhead rather than arithmetic.
let
k = Ref(0.0)
x = rand(1)
y = rand(1)
SUITE["high-level"]["einsum_spmv_compile_overhead"][scheduler_name] = @benchmarkable(
begin
A, x, y = (A, $x, $y)
@einsum y[i] += A[i, j] * x[j]
end,
# `$k[] += 1` gives A a distinct fill value on every setup run — presumably to
# defeat plan/kernel caching so each sample pays the compile cost; TODO confirm.
setup = (Finch.set_scheduler!($scheduler); A = Tensor(Dense(SparseList(Element($k[] += 1))), fsprand(1, 1, 1)))
)
end

let
N = 10
P = 0.0001
C = 16.0
SUITE["high-level"]["einsum_matmul_adaptive_overhead"][scheduler_name] = @benchmarkable(
begin
@einsum C[i, j] += A[i, k] * B[k, j]
end,
setup = begin
(N, P, C) = ($N, $P, $C)
n = floor(Int, N * C^(rand()))
m = floor(Int, N * C^(rand()))
l = floor(Int, N * C^(rand()))
p = floor(Int, P * C^(rand()))
q = floor(Int, P * C^(rand()))
A = fsprand(n, l, p)
B = fsprand(l, m, q)
end,
evals = 1
)
let
N = 100000
function generate_kernel_defs()
for nnz1 in reverse([4, 4^2, 4^3, 4^4])
for nnz2 in reverse([4, 4^2, 4^3, 4^4])
for nnz3 in reverse([4, 4^2, 4^3, 4^4])
A = lazy(fsprand(N, N, nnz1))
B = lazy(fsprand(N, N, nnz2))
C = lazy(fsprand(N, N, nnz3))
compute(A * B * C)
end
end
end
end

let
A = Tensor(Dense(SparseList(Element(0.0))), fsprand(1, 1, 1))
x = rand(1)
SUITE["high-level"]["einsum_spmv_call_overhead"][scheduler_name] = @benchmarkable(
begin
A, x = ($A, $x)
@einsum y[i] += A[i, j] * x[j]
end,
)
end
SUITE["high-level"]["matchain_adaptive_overhead"][scheduler_name] = @benchmarkable(
begin
compute(A * B * C)
end,
setup = begin
Finch.set_scheduler!($scheduler)
N = $N
$generate_kernel_defs()
A = lazy(fsprand(N, N, 4))
B = lazy(fsprand(N, N, 4))
C = lazy(fsprand(N, N, 4))
end,
evals = 1
)
end

let
N = 1_000
K = 1_000
p = 0.001
A = Tensor(Dense(Dense(Element(0.0))), rand(N, K))
B = Tensor(Dense(Dense(Element(0.0))), rand(K, N))
M = Tensor(Dense(SparseList(Element(0.0))), fsprand(N, N, p))

SUITE["high-level"]["sddmm_fused"][scheduler_name] = @benchmarkable(
begin
M = lazy($M)
A = lazy($A)
B = lazy($B)
compute(M .* (A * B))
end,
)

SUITE["high-level"]["sddmm_unfused"][scheduler_name] = @benchmarkable(
begin
M = $M
A = $A
B = $B
M .* (A * B)
end,
)
end
# Measure per-call overhead of @einsum SpMV on a 1x1 input: the work is negligible,
# so the benchmark isolates dispatch/plan-lookup cost of the scheduler under test.
let
A = Tensor(Dense(SparseList(Element(0.0))), fsprand(1, 1, 1))
x = rand(1)
y = rand(1)
SUITE["high-level"]["einsum_spmv_call_overhead"][scheduler_name] = @benchmarkable(
begin
A, x, y = ($A, $x, $y)
@einsum y[i] += A[i, j] * x[j]
end,
# evals = 1: one call per sample, so warm-call caching across evals cannot hide overhead
setup = (Finch.set_scheduler!($scheduler);),
evals = 1
)
end

# Same per-call-overhead measurement as einsum_spmv_call_overhead, but through the
# lazy/compute interface instead of @einsum.
let
A = Tensor(Dense(SparseList(Element(0.0))), fsprand(1, 1, 1))
x = rand(1)
SUITE["high-level"]["compute_spmv_call_overhead"][scheduler_name] = @benchmarkable(
begin
A, x = (lazy($A), lazy($x))
compute(A * x)
end,
# evals = 1 keeps each sample to a single cold call of the scheduler's dispatch path
setup = (Finch.set_scheduler!($scheduler);),
evals = 1
)
end

# SDDMM (sampled dense-dense matrix multiplication) benchmarks: M .* (A * B) where M is
# sparse and A, B are dense. The fused variant routes through lazy/compute so the
# scheduler may fuse the mask with the matmul; the unfused variant evaluates eagerly.
let
N = 1_000
K = 1_000
p = 0.001
A = Tensor(Dense(Dense(Element(0.0))), rand(N, K))
B = Tensor(Dense(Dense(Element(0.0))), rand(K, N))
M = Tensor(Dense(SparseList(Element(0.0))), fsprand(N, N, p))

SUITE["high-level"]["sddmm_fused"][scheduler_name] = @benchmarkable(
begin
M = lazy($M)
A = lazy($A)
B = lazy($B)
compute(M .* (A * B))
end,
setup = (Finch.set_scheduler!($scheduler)),
)

SUITE["high-level"]["sddmm_unfused"][scheduler_name] = @benchmarkable(
begin
M = $M
A = $A
B = $B
M .* (A * B)
end,
setup = (Finch.set_scheduler!($scheduler)),
)
end
end


Expand Down
2 changes: 1 addition & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ include("interface/einsum.jl")
include("Galley/Galley.jl")
using .Galley

export galley_scheduler
export galley_scheduler, GalleyOptimizer, AdaptiveExecutorCode, AdaptiveExecutor

@deprecate default fill_value
@deprecate redefault! set_fill_value!
Expand Down
4 changes: 2 additions & 2 deletions src/FinchLogic/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ function Base.:(==)(a::LogicNode, b::LogicNode)
elseif a.kind === immediate
return b.kind === immediate && a.val === b.val
elseif a.kind === deferred
return b.kind === deferred && a.val === b.val && a.type === b.type
return b.kind === deferred && a.ex === b.ex && a.type === b.type
elseif a.kind === field
return b.kind === field && a.name == b.name
elseif a.kind === alias
Expand All @@ -370,7 +370,7 @@ function Base.hash(a::LogicNode, h::UInt)
elseif istree(a)
return hash(a.kind, hash(a.children, h))
elseif a.kind === deferred
return hash(a.kind, hash(a.val, hash(a.type, h)))
return hash(a.kind, hash(a.ex, hash(a.type, h)))
else
error("unimplemented")
end
Expand Down
81 changes: 79 additions & 2 deletions src/Galley/FinchCompat/executor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,93 @@
julia_prgm
end

"""
    Finch.set_options(ctx::GalleyOptimizer; estimator=DCStats, verbose=false)

Set the optimizer's options in place and return it. `estimator` selects the
sparsity-statistics estimator used during plan optimization; `verbose` enables
diagnostic output. Mutates and returns the same `ctx`.
"""
function Finch.set_options(ctx::GalleyOptimizer; estimator=DCStats, verbose=false)
    ctx.estimator = estimator
    ctx.verbose = verbose
    return ctx
end

"""
    get_stats_dict(ctx::GalleyOptimizer, prgm)

Returns a dictionary mapping the location of input tensors in the program to their
statistics objects.
"""
function get_stats_dict(ctx::GalleyOptimizer, prgm)
    deferred = Finch.defer_tables(:prgm, prgm)
    stats_by_expr = Dict()
    # Walk the deferred program bottom-up and record statistics for every input table.
    for node in PostOrderDFS(deferred)
        node.kind == table || continue
        index_names = [idx.name for idx in node.idxs]
        stats_by_expr[node.tns.ex] = ctx.estimator(node.tns.imm, index_names)
    end
    return stats_by_expr
end

"""
    AdaptiveExecutor(ctx::GalleyOptimizer; threshold=2, verbose=false)

Executes a logic program by compiling it with the given compiler `ctx`. Compiled
codes are cached for each program structure. It first checks the cache for a plan that
was compiled for similar inputs and only compiles if it doesn't find one.
"""
@kwdef struct AdaptiveExecutor
    ctx        # the underlying GalleyOptimizer used to compile plans
    threshold  # similarity tolerance passed to `issimilar` when matching cached plans
    verbose    # when true, print the compiled code before executing
end

# Value equality and a consistent hash over all three fields — presumably so
# AdaptiveExecutor instances can serve as dictionary keys in scheduler caches
# (NOTE(review): confirm against callers; `==` and `hash` must stay in sync).
Base.:(==)(a::AdaptiveExecutor, b::AdaptiveExecutor) = a.ctx == b.ctx && a.threshold == b.threshold && a.verbose == b.verbose
Base.hash(a::AdaptiveExecutor, h::UInt) = hash(AdaptiveExecutor, hash(a.ctx, hash(a.threshold, hash(a.verbose, h))))

Check warning on line 81 in src/Galley/FinchCompat/executor.jl

View check run for this annotation

Codecov / codecov/patch

src/Galley/FinchCompat/executor.jl#L80-L81

Added lines #L80 - L81 were not covered by tests

# Keyword-style convenience constructor with the default threshold of 2.
AdaptiveExecutor(ctx::GalleyOptimizer; threshold = 2, verbose = false) = AdaptiveExecutor(ctx, threshold, verbose)
# Rebuild the executor with updated options, forwarding remaining kwargs to the inner
# optimizer. NOTE(review): `tag` is accepted and silently ignored — appears intentional
# (for interface compatibility with other executors' `set_options`); confirm.
function Finch.set_options(ctx::AdaptiveExecutor; threshold = 2, verbose = ctx.verbose, tag=:global, kwargs...)
AdaptiveExecutor(Finch.set_options(ctx.ctx; verbose=verbose, kwargs...), threshold, verbose)
end

# Global cache of compiled plans, keyed by (optimizer, threshold, program structure).
# Each value is a list of (stats_dict, (compiled_function, code)) pairs. Declared
# `const` because a non-const global is `Any`-typed in Julia; the Dict is mutated
# in place, never rebound, so `const` is behavior-compatible.
const galley_codes = Dict()

"""
    (ctx::AdaptiveExecutor)(prgm)

Execute the logic program `prgm`, reusing a cached compiled plan whose recorded input
statistics are all within `ctx.threshold` (per `issimilar`) of the current inputs'
statistics, and compiling a fresh plan otherwise.
"""
function (ctx::AdaptiveExecutor)(prgm)
    cur_stats_dict = get_stats_dict(ctx.ctx, prgm)
    stats_list = get!(galley_codes, (ctx.ctx, ctx.threshold, Finch.get_structure(prgm)), [])
    # Look for a previously-compiled plan whose inputs were "similar enough".
    valid_match = nothing
    for (stats_dict, f_code) in stats_list
        if all(issimilar(cur_stats, stats_dict[cur_expr], ctx.threshold) for (cur_expr, cur_stats) in cur_stats_dict)
            valid_match = f_code
            break
        end
    end
    if isnothing(valid_match)
        # No similar plan cached: compile one for the current inputs and remember it.
        thunk = Finch.logic_executor_code(ctx.ctx, prgm)
        valid_match = (eval(thunk), thunk)
        push!(stats_list, (cur_stats_dict, valid_match))
    end
    (f, code) = valid_match
    if ctx.verbose
        println("Executing:")
        display(code)
    end
    # invokelatest is required because `f` may have been eval'd in this world age.
    return Base.invokelatest(f, prgm)
end

"""
    AdaptiveExecutorCode(ctx)

Return the code that would normally be used by the AdaptiveExecutor to run a program.
"""
struct AdaptiveExecutorCode
ctx  # the optimizer whose generated code is returned (not executed)
end

# Produce the compiled code for `prgm` under this optimizer without eval'ing or
# executing it — useful for inspecting what the AdaptiveExecutor would run.
# (This span in the scraped diff contained embedded Codecov annotation noise,
# removed here; the code itself is unchanged.)
function (ctx::AdaptiveExecutorCode)(prgm)
    return Finch.logic_executor_code(ctx.ctx, prgm)
end

"""
    galley_scheduler(; threshold=2, verbose=false)

The galley scheduler uses the sparsity patterns of the inputs to optimize the
computation. Compiled plans are cached per approximate sparsity pattern: a new set of
inputs reuses a cached plan when its statistics are within `threshold` of those the
plan was compiled for (see `issimilar`), and triggers a fresh compilation otherwise.
"""
galley_scheduler(; threshold=2, verbose=false) = AdaptiveExecutor(GalleyOptimizer(; verbose=verbose); threshold=threshold, verbose=verbose)

2 changes: 1 addition & 1 deletion src/Galley/Galley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export PlanNode, Value, Index, Alias, Input, MapJoin, Aggregate, Materialize, Qu
export Scalar, Σ, Mat, Agg
export DCStats, NaiveStats, TensorDef, DC, insert_statistics
export naive, greedy, pruned, exact
export GalleyOptimizer, galley_scheduler
export GalleyOptimizer, AdaptiveExecutor, AdaptiveExecutorCode, galley_scheduler

IndexExpr = Symbol
TensorId = Symbol
Expand Down
4 changes: 3 additions & 1 deletion src/Galley/TensorStats/propagate-stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@
return merge_tensor_stats_join(op, new_def, join_like_args...)
elseif length(join_like_args) == 0
return merge_tensor_stats_union(op, new_def, union_like_args...)
else
elseif union([get_index_set(stats) for stats in join_like_args]...) == get_index_set(new_def)
# Currently we glean no information from non-join-like args
return merge_tensor_stats_join(op, new_def, join_like_args...)
else
return merge_tensor_stats_union(op, new_def, join_like_args..., union_like_args...)

Check warning on line 92 in src/Galley/TensorStats/propagate-stats.jl

View check run for this annotation

Codecov / codecov/patch

src/Galley/TensorStats/propagate-stats.jl#L92

Added line #L92 was not covered by tests
end
end

Expand Down
Loading
Loading