diff --git a/spgemm/run_spgemm.jl b/spgemm/run_spgemm.jl
index 8b26823c..51e0adaa 100644
--- a/spgemm/run_spgemm.jl
+++ b/spgemm/run_spgemm.jl
@@ -64,6 +64,8 @@ datasets = Dict(
 include("spgemm_finch.jl")
 include("spgemm_taco.jl")
 
+include("spgemm_finch_par.jl")
+
 results = []
 
 
@@ -72,6 +74,7 @@ for mtx in datasets[parsed_args["dataset"]]
     B = A
     C_ref = nothing
     for (key, method) in [
+        "spgemm_finch_gustavson_parallel" => spgemm_finch_gustavson_parallel,
         "spgemm_taco_inner" => spgemm_taco_inner,
         "spgemm_taco_gustavson" => spgemm_taco_gustavson,
         "spgemm_taco_outer" => spgemm_taco_outer,
diff --git a/spgemm/spgemm_finch_par.jl b/spgemm/spgemm_finch_par.jl
new file mode 100644
index 00000000..998138c9
--- /dev/null
+++ b/spgemm/spgemm_finch_par.jl
@@ -0,0 +1,40 @@
+using Finch
+using BenchmarkTools
+using Base.Threads
+
+
+function spgemm_finch_gustavson_kernel_parallel(A, B)
+    # @assert Threads.nthreads() >= 2
+    z = default(A) * default(B) + false
+    C = Tensor(Dense(Seperation(SparseList(Element(z)))))
+    w = moveto(Tensor(Dense(Element(z))), CPULocalMemory(CPU()))
+    @finch_code begin
+        C .= 0
+        for j=parallel(_)
+            w .= 0
+            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
+            for i=_; C[i, j] = w[i] end
+        end
+    end
+    @finch begin
+        C .= 0
+        for j=parallel(_)
+            w .= 0
+            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
+            for i=_; C[i, j] = w[i] end
+        end
+    end
+    return C
+end
+
+
+function spgemm_finch_parallel(f, A, B)
+    _A = Tensor(A)
+    _B = Tensor(B)
+    C = Ref{Any}()
+    time = @belapsed $C[] = $f($_A, $_B)
+    return (;time = time, C = C[])
+end
+
+
+spgemm_finch_gustavson_parallel(A, B) = spgemm_finch_parallel(spgemm_finch_gustavson_kernel_parallel, A, B)
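
The loop nest in spgemm_finch_par.jl is Gustavson's column-wise SpGEMM with a dense scratch column: for each column j of B, every nonzero B[k, j] scales column A[:, k] into the workspace w, and w is then scattered into C[:, j]; moving w to CPULocalMemory gives each task of the parallel j loop its own private accumulator. A minimal serial sketch of the same pattern over SparseMatrixCSC inputs, for reference only (gustavson_spgemm_ref and all names below are illustrative, not part of the patch):

using SparseArrays

function gustavson_spgemm_ref(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner dimensions must match"))
    m, n = size(A, 1), size(B, 2)
    w = zeros(Tv, m)          # dense accumulator for one column of C (the role of w above)
    occupied = falses(m)      # tracks which rows of the current column have been written
    I = Ti[]; J = Ti[]; V = Tv[]
    Arv, Anz = rowvals(A), nonzeros(A)
    Brv, Bnz = rowvals(B), nonzeros(B)
    for j in 1:n
        pattern = Ti[]
        for p in nzrange(B, j)            # k ranges over the nonzeros of B[:, j]
            k = Brv[p]; bkj = Bnz[p]
            for q in nzrange(A, k)        # i ranges over the nonzeros of A[:, k]
                i = Arv[q]
                if !occupied[i]
                    occupied[i] = true
                    push!(pattern, i)
                end
                w[i] += Anz[q] * bkj
            end
        end
        for i in pattern                  # scatter column j of C, then reset the workspace
            push!(I, i); push!(J, j); push!(V, w[i])
            w[i] = zero(Tv); occupied[i] = false
        end
    end
    return sparse(I, J, V, m, n)
end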
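
A usage sketch for the new entry point, not part of the patch: it assumes Finch and BenchmarkTools are installed and that Julia is launched with several threads so the parallel j loop has workers (e.g. julia -t 8 run_spgemm.jl --dataset <name> when going through the harness, or standalone as below; operand sizes and density are illustrative).

using SparseArrays
include("spgemm_finch_par.jl")

A = sprand(Float64, 2_000, 2_000, 0.01)        # random sparse operands (illustrative)
B = sprand(Float64, 2_000, 2_000, 0.01)

res = spgemm_finch_gustavson_parallel(A, B)    # returns (; time, C) as defined in the patch
println("parallel Gustavson SpGEMM: ", res.time, " s")

The wrapper times only the kernel call with @belapsed; the Tensor conversions of A and B happen outside the timed expression.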