Adding the possibility of prescribing weights before redistributing
amartinhuertas committed Jul 12, 2024
1 parent 743bd10 commit c4424ea
Showing 3 changed files with 87 additions and 17 deletions.
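A minimal usage sketch of the new keyword argument (not part of the commit; it assumes a distributed model `model::OctreeDistributedDiscreteModel` and the `ranks` array used to distribute it already exist, and it mirrors the weights construction added to the test file below):

# One Cint weight per local cell (owned + ghost cells); per the comment added in
# redistribute, only the entries for owned cells are taken into account.
weights = map(ranks, model.dmodel.models) do rank, lmodel
  rank % 2 == 0 ? zeros(Cint, num_cells(lmodel)) : ones(Cint, num_cells(lmodel))
end

# Weighted redistribution: cells migrate so that the accumulated weight,
# rather than the raw cell count, is balanced across parts.
model_red, glue = GridapDistributed.redistribute(model; weights=weights)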
32 changes: 28 additions & 4 deletions src/OctreeDistributedDiscreteModels.jl
@@ -1599,9 +1599,23 @@ end
# Assumptions. Either:
# A) model.parts MPI tasks are included in parts_redistributed_model MPI tasks; or
# B) model.parts MPI tasks include parts_redistributed_model MPI tasks
const WeightsArrayType=Union{Nothing,MPIArray{<:Vector{<:Integer}}}
function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc,Dp},
parts_redistributed_model=model.parts) where {Dc,Dp}
parts_redistributed_model=model.parts;
weights::WeightsArrayType=nothing) where {Dc,Dp}
parts = (parts_redistributed_model === model.parts) ? model.parts : parts_redistributed_model
_weights=nothing
if (weights !== nothing)
Gridap.Helpers.@notimplementedif parts!==model.parts
_weights=map(model.dmodel.models,weights) do lmodel,weights
# The length of the local weights array has to match the number of
# cells in the model. This includes both owned and ghost cells.
# Only the flags for owned cells are actually taken into account.
@assert num_cells(lmodel)==length(weights)
convert(Vector{Cint},weights)
end
end

comm = parts.comm
if (GridapDistributed.i_am_in(model.parts.comm) || GridapDistributed.i_am_in(parts.comm))
if (parts_redistributed_model !== model.parts)
@@ -1610,7 +1624,7 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc
@assert A || B
end
if (parts_redistributed_model===model.parts || A)
_redistribute_parts_subseteq_parts_redistributed(model,parts_redistributed_model)
_redistribute_parts_subseteq_parts_redistributed(model,parts_redistributed_model,_weights)
else
_redistribute_parts_supset_parts_redistributed(model, parts_redistributed_model)
end
@@ -1619,7 +1633,9 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc
end
end

function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistributedDiscreteModel{Dc,Dp}, parts_redistributed_model) where {Dc,Dp}
function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistributedDiscreteModel{Dc,Dp},
parts_redistributed_model,
_weights::WeightsArrayType) where {Dc,Dp}
parts = (parts_redistributed_model === model.parts) ? model.parts : parts_redistributed_model
if (parts_redistributed_model === model.parts)
ptr_pXest_old = model.ptr_pXest
@@ -1631,7 +1647,15 @@ function _redistribute_parts_subseteq_parts_redistribut
parts.comm)
end
ptr_pXest_new = pXest_copy(model.pXest_type, ptr_pXest_old)
pXest_partition!(model.pXest_type, ptr_pXest_new)
if (_weights !== nothing)
init_fn_callback_c = pXest_reset_callbacks(model.pXest_type)
map(_weights) do _weights
pXest_reset_data!(model.pXest_type, ptr_pXest_new, Cint(sizeof(Cint)), init_fn_callback_c, pointer(_weights))
end
pXest_partition!(model.pXest_type, ptr_pXest_new; weights_set=true)
else
pXest_partition!(model.pXest_type, ptr_pXest_new; weights_set=false)
end

# Compute RedistributeGlue
parts_snd, lids_snd, old2new = pXest_compute_migration_control_data(model.pXest_type,ptr_pXest_old,ptr_pXest_new)
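As a complement to the weighted path above, a hypothetical helper (not part of this commit) sketching one way such weights could be computed, e.g. to balance the load expected after refinement; `num_cells` and the `dmodel.models` layout are used as in the diff above, while the weight values themselves are an assumption:

# Hypothetical helper: weight each local cell by the number of children it will
# have after refinement, so the weighted partition balances the post-refinement
# load. The returned vectors have length num_cells(lmodel) (owned + ghost cells).
function refinement_aware_weights(dmodel, refine_flags; Dc=2)
  map(dmodel.models, refine_flags) do lmodel, flags
    w = ones(Cint, num_cells(lmodel))
    for (cell, flag) in enumerate(flags)
      flag && (w[cell] = Cint(2^Dc))   # 4 children in 2D, 8 in 3D
    end
    w
  end
end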
51 changes: 45 additions & 6 deletions src/PXestTypeMethods.jl
@@ -309,16 +309,31 @@ function pXest_balance!(::P8estType, ptr_pXest; k_2_1_balance=0)
end
end

function pXest_partition!(::P4estType, ptr_pXest)
p4est_partition(ptr_pXest, 0, C_NULL)
function pXest_partition!(pXest_type::P4estType, ptr_pXest; weights_set=false)
if (!weights_set)
p4est_partition(ptr_pXest, 0, C_NULL)
else
wcallback=pXest_weight_callback(pXest_type)
p4est_partition(ptr_pXest, 0, wcallback)
end
end

function pXest_partition!(::P6estType, ptr_pXest)
p6est_partition(ptr_pXest, C_NULL)
function pXest_partition!(pXest_type::P6estType, ptr_pXest; weights_set=false)
if (!weights_set)
p6est_partition(ptr_pXest, C_NULL)
else
wcallback=pXest_weight_callback(pXest_type)
p6est_partition(ptr_pXest, wcallback)
end
end

function pXest_partition!(::P8estType, ptr_pXest)
p8est_partition(ptr_pXest, 0, C_NULL)
function pXest_partition!(pXest_type::P8estType, ptr_pXest; weights_set=false)
if (!weights_set)
p8est_partition(ptr_pXest, 0, C_NULL)
else
wcallback=pXest_weight_callback(pXest_type)
p8est_partition(ptr_pXest, 0, wcallback)
end
end


@@ -805,6 +820,30 @@ function pXest_refine_callbacks(::P8estType)
refine_callback_c, refine_replace_callback_c
end

function pXest_weight_callback(::P4estType)
function weight_callback(::Ptr{p4est_t},
which_tree::p4est_topidx_t,
quadrant_ptr::Ptr{p4est_quadrant_t})
quadrant = quadrant_ptr[]
return unsafe_wrap(Array, Ptr{Cint}(quadrant.p.user_data), 1)[]
end
@cfunction($weight_callback, Cint, (Ptr{p4est_t}, p4est_topidx_t, Ptr{p4est_quadrant_t}))
end

function pXest_weight_callback(::P6estType)
Gridap.Helpers.@notimplemented
end

function pXest_weight_callback(::P8estType)
function weight_callback(::Ptr{p8est_t},
which_tree::p4est_topidx_t,
quadrant_ptr::Ptr{p8est_quadrant_t})
quadrant = quadrant_ptr[]
return unsafe_wrap(Array, Ptr{Cint}(quadrant.p.user_data), 1)[]
end
@cfunction($weight_callback, Cint, (Ptr{p8est_t}, p4est_topidx_t, Ptr{p8est_quadrant_t}))
end

function _unwrap_ghost_quadrants(::P4estType, pXest_ghost)
Ptr{p4est_quadrant_t}(pXest_ghost.ghosts.array)
end
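A self-contained sketch (plain Julia, no p4est calls) of the pointer read performed by the weight callbacks above: each quadrant carries a single Cint weight in its user_data, written beforehand through pXest_reset_data! in OctreeDistributedDiscreteModels.jl, and the callback recovers it by wrapping the pointer in a 1-element array:

# Store one Cint behind a raw pointer and read it back the same way the
# weight callbacks do.
w   = Ref{Cint}(3)
ptr = Base.unsafe_convert(Ptr{Cint}, w)
GC.@preserve w begin
  @assert unsafe_wrap(Array, ptr, 1)[] == Cint(3)
end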
21 changes: 14 additions & 7 deletions test/PoissonNonConformingOctreeModelsTests.jl
@@ -139,7 +139,14 @@ module PoissonNonConformingOctreeModelsTests
e = uH - uhH
el2 = sqrt(sum( ∫( e⋅e )*dΩH ))

fmodel_red, red_glue=GridapDistributed.redistribute(fmodel);
weights=map(ranks,fmodel.dmodel.models) do rank,lmodel
if (rank%2==0)
zeros(Cint,num_cells(lmodel))
else
ones(Cint,num_cells(lmodel))
end
end
fmodel_red, red_glue=GridapDistributed.redistribute(fmodel,weights=weights);
Vhred=FESpace(fmodel_red,reffe,conformity=:H1;dirichlet_tags="boundary")
Uhred=TrialFESpace(Vhred,u)

@@ -274,12 +281,12 @@ module PoissonNonConformingOctreeModelsTests
#debug_logger = ConsoleLogger(stderr, Logging.Debug)
#global_logger(debug_logger); # Enable the debug logger globally
ranks = distribute(LinearIndices((MPI.Comm_size(MPI.COMM_WORLD),)))
for Dc=3:3, perm=1:4, order=1:4, scalar_or_vector in (:scalar,)
test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector))
end
for Dc=2:3, perm in (1,2), order in (1,4), scalar_or_vector in (:vector,)
test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector))
end
# for Dc=3:3, perm=1:4, order=1:4, scalar_or_vector in (:scalar,)
# test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector))
# end
# for Dc=2:3, perm in (1,2), order in (1,4), scalar_or_vector in (:vector,)
# test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector))
# end
for order=2:2, scalar_or_vector in (:scalar,:vector)
test_2d(ranks,order,_field_type(Val{2}(),scalar_or_vector), num_amr_steps=5)
test_3d(ranks,order,_field_type(Val{3}(),scalar_or_vector), num_amr_steps=4)

