diff --git a/docs/src/literate-howto/threaded_assembly.jl b/docs/src/literate-howto/threaded_assembly.jl index 8f47cccc75..8cbcca0c3c 100644 --- a/docs/src/literate-howto/threaded_assembly.jl +++ b/docs/src/literate-howto/threaded_assembly.jl @@ -171,13 +171,12 @@ end # purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but # note that we set `fillzero = false` because we don't want to risk that a task that starts # a bit later will zero out data that another task have already assembled. -function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues) - cell_cache = CellCache(dh) - n = ndofs_per_cell(dh) - Ke = zeros(n, n) - fe = zeros(n) - asm = start_assemble(K, f; fillzero = false) - return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm) +function Ferrite.task_local(scratch::ScratchData) + ScratchData( + task_local(scratch.cell_cache), task_local(scratch.cellvalues), + task_local(scratch.Ke), task_local(scratch.fe), + task_local(scratch.assembler) + ) end nothing # hide @@ -221,11 +220,15 @@ function assemble_global!( K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors, cellvalues_template::CellValues; ntasks = Threads.nthreads() ) - ## Zero-out existing data in K and f - _ = start_assemble(K, f) ## Body force and material stiffness b = Vec{3}((0.0, 0.0, -1.0)) C = create_material_stiffness() + ## Scratch data + scratch_template = ScratchData( + CellCache(dh), cellvalues_template, + zeros(ndofs_per_cell(dh), ndofs_per_cell(dh)), zeros(ndofs_per_cell(dh)), + start_assemble(K, f) + ) ## Loop over the colors for color in colors ## Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of @@ -236,7 +239,7 @@ function assemble_global!( ## Tell the @tasks loop to use the scheduler defined above @set scheduler = scheduler ## Obtain a task local scratch and unpack it - @local scratch = ScratchData(dh, K, f, cellvalues_template) + @local scratch = task_local(scratch_template) (; cell_cache, cellvalues, Ke, fe, assembler) = scratch ## Reinitialize the cell cache and then the cellvalues reinit!(cell_cache, cellidx) @@ -258,7 +261,7 @@ nothing # hide # ```julia # # using TaskLocalValues # scratches = TaskLocalValue() do -# ScratchData(dh, K, f, cellvalues) +# task_local(scratch_template) # end # OhMyThreads.tforeach(color; scheduler) do cellidx # # Obtain a task local scratch and unpack it diff --git a/src/FEValues/CellValues.jl b/src/FEValues/CellValues.jl index 0403814b94..7b9516b84d 100644 --- a/src/FEValues/CellValues.jl +++ b/src/FEValues/CellValues.jl @@ -66,6 +66,12 @@ end function Base.copy(cv::CellValues) return CellValues(copy(cv.fun_values), copy(cv.geo_mapping), copy(cv.qr), _copy_or_nothing(cv.detJdV)) end +function task_local(cv::CellValues) + return CellValues( + task_local(cv.fun_values), task_local(cv.geo_mapping), task_local(cv.qr), + task_local(cv.detJdV) + ) +end # Access geometry values @propagate_inbounds getngeobasefunctions(cv::CellValues) = getngeobasefunctions(cv.geo_mapping) diff --git a/src/FEValues/FunctionValues.jl b/src/FEValues/FunctionValues.jl index 67cdf5ad44..46664fe200 100644 --- a/src/FEValues/FunctionValues.jl +++ b/src/FEValues/FunctionValues.jl @@ -107,6 +107,14 @@ function Base.copy(v::FunctionValues) d2Ndξ2_copy = _copy_or_nothing(v.d2Ndξ2) return FunctionValues(copy(v.ip), Nx_copy, Nξ_copy, dNdx_copy, dNdξ_copy, d2Ndx2_copy, d2Ndξ2_copy) end +function task_local(v::FunctionValues) + Nξ = task_local(v.Nξ) + Nx = v.Nξ === v.Nx ? Nξ : task_local(v.Nx) # Preserve aliasing + return FunctionValues( + task_local(v.ip), Nx, Nξ, task_local(v.dNdx), task_local(v.dNdξ), + task_local(v.d2Ndx2), task_local(v.d2Ndξ2) + ) +end getnbasefunctions(funvals::FunctionValues) = size(funvals.Nx, 1) @propagate_inbounds shape_value(funvals::FunctionValues, q_point::Int, base_func::Int) = funvals.Nx[base_func, q_point] diff --git a/src/FEValues/GeometryMapping.jl b/src/FEValues/GeometryMapping.jl index b5cede1a58..c722d6cf80 100644 --- a/src/FEValues/GeometryMapping.jl +++ b/src/FEValues/GeometryMapping.jl @@ -97,6 +97,11 @@ end function Base.copy(v::GeometryMapping) return GeometryMapping(copy(v.ip), copy(v.M), _copy_or_nothing(v.dMdξ), _copy_or_nothing(v.d2Mdξ2)) end +function task_local(v::GeometryMapping) + return GeometryMapping( + task_local(v.ip), task_local(v.M), task_local(v.dMdξ), task_local(v.d2Mdξ2) + ) +end getngeobasefunctions(geo_mapping::GeometryMapping) = size(geo_mapping.M, 1) @propagate_inbounds geometric_value(geo_mapping::GeometryMapping, q_point::Int, base_func::Int) = geo_mapping.M[base_func, q_point] diff --git a/src/Ferrite.jl b/src/Ferrite.jl index 2738904d83..8a378688dc 100644 --- a/src/Ferrite.jl +++ b/src/Ferrite.jl @@ -32,6 +32,8 @@ using .CollectionsOfViews: include("exports.jl") +# Task based multithreading support +include("multithreading.jl") """ AbstractRefShape{refdim} diff --git a/src/Quadrature/quadrature.jl b/src/Quadrature/quadrature.jl index bb32acdb2d..9b725682b8 100644 --- a/src/Quadrature/quadrature.jl +++ b/src/Quadrature/quadrature.jl @@ -313,3 +313,4 @@ getrefshape(::QuadratureRule{RefShape}) where RefShape = RefShape # TODO: This is used in copy(::(Cell|Face)Values), but it it useful to get an actual copy? Base.copy(qr::Union{QuadratureRule, FacetQuadratureRule}) = qr +task_local(qr::Union{QuadratureRule, FacetQuadratureRule}) = copy(qr) diff --git a/src/assembler.jl b/src/assembler.jl index e79521dc4b..21ef875115 100644 --- a/src/assembler.jl +++ b/src/assembler.jl @@ -143,6 +143,14 @@ matrix_handle(a::AbstractCSCAssembler) = a.K matrix_handle(a::SymmetricCSCAssembler) = a.K.data vector_handle(a::AbstractCSCAssembler) = a.f +function task_local(asm::CSCAssembler) + return CSCAssembler(asm.K, asm.f, task_local(asm.permutation), task_local(asm.sorteddofs)) +end +function task_local(asm::SymmetricCSCAssembler) + return SymmetricCSCAssembler(asm.K, asm.f, task_local(asm.permutation), task_local(asm.sorteddofs)) +end + + """ start_assemble(K::AbstractSparseMatrixCSC; fillzero::Bool=true) -> CSCAssembler start_assemble(K::AbstractSparseMatrixCSC, f::Vector; fillzero::Bool=true) -> CSCAssembler diff --git a/src/exports.jl b/src/exports.jl index c49b24eba0..9a237cae6c 100644 --- a/src/exports.jl +++ b/src/exports.jl @@ -184,4 +184,7 @@ export evaluate_at_points, PointIterator, PointLocation, - PointValues + PointValues, + +# Misc + task_local diff --git a/src/multithreading.jl b/src/multithreading.jl new file mode 100644 index 0000000000..6e075b4d8c --- /dev/null +++ b/src/multithreading.jl @@ -0,0 +1,14 @@ +function task_local end + +""" + task_local(A) + +Duplicate `A` for a new task. +""" +task_local(::Any) + +# Vector/Matrix (e.g. local matrix and vector) +task_local(A::Array) = copy(A) + +# To help with struct fields which are Union{X, Nothing} +task_local(::Nothing) = nothing