Skip to content

Commit

Permalink
WIP add task_local for duplication things for multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed Sep 26, 2024
1 parent 0762d01 commit 237c1cb
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 12 deletions.
25 changes: 14 additions & 11 deletions docs/src/literate-howto/threaded_assembly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,12 @@ end
# purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but
# note that we set `fillzero = false` because we don't want to risk that a task that starts
# a bit later will zero out data that another task has already assembled.
function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues)
cell_cache = CellCache(dh)
n = ndofs_per_cell(dh)
Ke = zeros(n, n)
fe = zeros(n)
asm = start_assemble(K, f; fillzero = false)
return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm)
function Ferrite.task_local(scratch::ScratchData)
    ## Duplicate every field of the scratch with `task_local`, one at a time
    cell_cache = task_local(scratch.cell_cache)
    cellvalues = task_local(scratch.cellvalues)
    Ke = task_local(scratch.Ke)
    fe = task_local(scratch.fe)
    assembler = task_local(scratch.assembler)
    ## Assemble the duplicated fields into a new scratch object
    return ScratchData(cell_cache, cellvalues, Ke, fe, assembler)
end
nothing # hide

Expand Down Expand Up @@ -221,11 +220,15 @@ function assemble_global!(
K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors,
cellvalues_template::CellValues; ntasks = Threads.nthreads()
)
## Zero-out existing data in K and f
_ = start_assemble(K, f)
## Body force and material stiffness
b = Vec{3}((0.0, 0.0, -1.0))
C = create_material_stiffness()
## Scratch data
scratch_template = ScratchData(
CellCache(dh), cellvalues_template,
zeros(ndofs_per_cell(dh), ndofs_per_cell(dh)), zeros(ndofs_per_cell(dh)),
start_assemble(K, f)
)
## Loop over the colors
for color in colors
## Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of
Expand All @@ -236,7 +239,7 @@ function assemble_global!(
## Tell the @tasks loop to use the scheduler defined above
@set scheduler = scheduler
## Obtain a task local scratch and unpack it
@local scratch = ScratchData(dh, K, f, cellvalues_template)
@local scratch = task_local(scratch_template)
(; cell_cache, cellvalues, Ke, fe, assembler) = scratch
## Reinitialize the cell cache and then the cellvalues
reinit!(cell_cache, cellidx)
Expand All @@ -258,7 +261,7 @@ nothing # hide
# ```julia
# # using TaskLocalValues
# scratches = TaskLocalValue() do
# ScratchData(dh, K, f, cellvalues)
# task_local(scratch_template)
# end
# OhMyThreads.tforeach(color; scheduler) do cellidx
# # Obtain a task local scratch and unpack it
Expand Down
6 changes: 6 additions & 0 deletions src/FEValues/CellValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ end
# Copy a `CellValues` field by field. `fun_values`, `geo_mapping` and `qr` go through
# `copy`, while `detJdV` goes through `_copy_or_nothing`.
function Base.copy(cv::CellValues)
    fun_values = copy(cv.fun_values)
    geo_mapping = copy(cv.geo_mapping)
    qr = copy(cv.qr)
    detJdV = _copy_or_nothing(cv.detJdV)
    return CellValues(fun_values, geo_mapping, qr, detJdV)
end
# Build a `CellValues` for use in a new task by calling `task_local` on each field.
function task_local(cv::CellValues)
    fun_values = task_local(cv.fun_values)
    geo_mapping = task_local(cv.geo_mapping)
    qr = task_local(cv.qr)
    detJdV = task_local(cv.detJdV)
    return CellValues(fun_values, geo_mapping, qr, detJdV)
end

# Access geometry values
# Forward the query to the geometry mapping stored in `cv`.
@propagate_inbounds function getngeobasefunctions(cv::CellValues)
    return getngeobasefunctions(cv.geo_mapping)
end
Expand Down
8 changes: 8 additions & 0 deletions src/FEValues/FunctionValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ function Base.copy(v::FunctionValues)
d2Ndξ2_copy = _copy_or_nothing(v.d2Ndξ2)
return FunctionValues(copy(v.ip), Nx_copy, Nξ_copy, dNdx_copy, dNdξ_copy, d2Ndx2_copy, d2Ndξ2_copy)
end
# Build a `FunctionValues` for use in a new task by calling `task_local` on each field.
# When the `Nx` and `Nξ` fields of `v` are the very same object (checked with `===`), the
# duplicate reuses the single new `Nξ` for both fields so the aliasing is kept.
function task_local(v::FunctionValues)
    Nξ = task_local(v.Nξ)
    Nx = v.Nξ === v.Nx ? Nξ : task_local(v.Nx) # Preserve aliasing
    return FunctionValues(
        task_local(v.ip), Nx, Nξ, task_local(v.dNdx), task_local(v.dNdξ),
        task_local(v.d2Ndx2), task_local(v.d2Ndξ2)
    )
end

# The number of base functions is the size of the first dimension of `Nx`.
function getnbasefunctions(funvals::FunctionValues)
    return size(funvals.Nx, 1)
end
# Look up the entry of `Nx` at row `base_func` and column `q_point`.
@propagate_inbounds function shape_value(funvals::FunctionValues, q_point::Int, base_func::Int)
    return funvals.Nx[base_func, q_point]
end
Expand Down
5 changes: 5 additions & 0 deletions src/FEValues/GeometryMapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ end
# Copy a `GeometryMapping` field by field. `ip` and `M` go through `copy`, while the
# derivative fields `dMdξ` and `d2Mdξ2` go through `_copy_or_nothing`.
function Base.copy(v::GeometryMapping)
    ip = copy(v.ip)
    M = copy(v.M)
    dMdξ = _copy_or_nothing(v.dMdξ)
    d2Mdξ2 = _copy_or_nothing(v.d2Mdξ2)
    return GeometryMapping(ip, M, dMdξ, d2Mdξ2)
end
# Build a `GeometryMapping` for use in a new task by calling `task_local` on each field.
function task_local(v::GeometryMapping)
    ip = task_local(v.ip)
    M = task_local(v.M)
    dMdξ = task_local(v.dMdξ)
    d2Mdξ2 = task_local(v.d2Mdξ2)
    return GeometryMapping(ip, M, dMdξ, d2Mdξ2)
end

# The number of geometric base functions is the size of the first dimension of `M`.
function getngeobasefunctions(geo_mapping::GeometryMapping)
    return size(geo_mapping.M, 1)
end
# Look up the entry of `M` at row `base_func` and column `q_point`.
@propagate_inbounds function geometric_value(geo_mapping::GeometryMapping, q_point::Int, base_func::Int)
    return geo_mapping.M[base_func, q_point]
end
Expand Down
2 changes: 2 additions & 0 deletions src/Ferrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ using .CollectionsOfViews:

include("exports.jl")

# Task based multithreading support
include("multithreading.jl")

"""
AbstractRefShape{refdim}
Expand Down
1 change: 1 addition & 0 deletions src/Quadrature/quadrature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,4 @@ getrefshape(::QuadratureRule{RefShape}) where RefShape = RefShape

# TODO: This is used in copy(::(Cell|Face)Values), but is it useful to get an actual copy?
# `copy` of a quadrature rule hands back the same object that was passed in.
function Base.copy(qr::Union{QuadratureRule, FacetQuadratureRule})
    return qr
end
# Task-local duplication of a quadrature rule defers to `copy`.
function task_local(qr::Union{QuadratureRule, FacetQuadratureRule})
    return copy(qr)
end
8 changes: 8 additions & 0 deletions src/assembler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ matrix_handle(a::AbstractCSCAssembler) = a.K
# For the symmetric assembler the matrix handle is the `data` field of the stored `K`.
function matrix_handle(a::SymmetricCSCAssembler)
    return a.K.data
end
# The vector handle is the `f` field of the assembler.
function vector_handle(a::AbstractCSCAssembler)
    return a.f
end

# Task-local `CSCAssembler`: `K` and `f` are passed through as-is, while the
# `permutation` and `sorteddofs` fields are duplicated with `task_local`.
function task_local(asm::CSCAssembler)
    permutation = task_local(asm.permutation)
    sorteddofs = task_local(asm.sorteddofs)
    return CSCAssembler(asm.K, asm.f, permutation, sorteddofs)
end
# Task-local `SymmetricCSCAssembler`: `K` and `f` are passed through as-is, while the
# `permutation` and `sorteddofs` fields are duplicated with `task_local`.
function task_local(asm::SymmetricCSCAssembler)
    permutation = task_local(asm.permutation)
    sorteddofs = task_local(asm.sorteddofs)
    return SymmetricCSCAssembler(asm.K, asm.f, permutation, sorteddofs)
end


"""
start_assemble(K::AbstractSparseMatrixCSC; fillzero::Bool=true) -> CSCAssembler
start_assemble(K::AbstractSparseMatrixCSC, f::Vector; fillzero::Bool=true) -> CSCAssembler
Expand Down
5 changes: 4 additions & 1 deletion src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,7 @@ export
evaluate_at_points,
PointIterator,
PointLocation,
PointValues
PointValues,

# Misc
task_local
14 changes: 14 additions & 0 deletions src/multithreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Declare the generic function without any methods; methods are added per type.
function task_local end

"""
    task_local(A)

Duplicate `A` for a new task.
"""
task_local(::Any)

# Plain arrays (for example a local matrix and vector) are duplicated with `copy`
function task_local(A::Array)
    return copy(A)
end

# `nothing` is passed through, which helps with struct fields typed `Union{X, Nothing}`
function task_local(::Nothing)
    return nothing
end

0 comments on commit 237c1cb

Please sign in to comment.