Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adaptive Particle Refinement: Add refinement callback #644

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/TrixiParticles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export WeaklyCompressibleSPHSystem, EntropicallyDampedSPHSystem, TotalLagrangian
BoundarySPHSystem, DEMSystem, BoundaryDEMSystem, OpenBoundarySPHSystem, InFlow,
OutFlow
export InfoCallback, SolutionSavingCallback, DensityReinitializationCallback,
PostprocessCallback, StepsizeCallback, UpdateCallback
PostprocessCallback, StepsizeCallback, UpdateCallback, ParticleRefinementCallback
export ContinuityDensity, SummationDensity
export PenaltyForceGanzenmueller, TransportVelocityAdami
export SchoenbergCubicSplineKernel, SchoenbergQuarticSplineKernel,
Expand Down
1 change: 1 addition & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ include("density_reinit.jl")
include("post_process.jl")
include("stepsize.jl")
include("update.jl")
include("refinement.jl")
120 changes: 120 additions & 0 deletions src/callbacks/refinement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
struct ParticleRefinementCallback{I}
interval::I
end

function ParticleRefinementCallback(; interval::Integer=-1, dt=0.0)
if dt > 0 && interval !== -1
throw(ArgumentError("setting both `interval` and `dt` is not supported"))
end

# Update in intervals in terms of simulation time
if dt > 0
interval = Float64(dt)

# Update every time step (default)
elseif interval == -1
interval = 1
end

refinement_callback = ParticleRefinementCallback(interval)

if dt > 0
# Add a `tstop` every `dt`, and save the final solution.
return PeriodicCallback(refinement_callback, dt,
initialize=initial_refinement!,
save_positions=(false, false))
else
# The first one is the `condition`, the second the `affect!`
return DiscreteCallback(refinement_callback, refinement_callback,
initialize=initial_refinement!,
save_positions=(false, false))
end
end

# initialize
function initial_refinement!(cb, u, t, integrator)
# The `ParticleRefinementCallback` is either `cb.affect!` (with `DiscreteCallback`)
# or `cb.affect!.affect!` (with `PeriodicCallback`).
# Let recursive dispatch handle this.

initial_refinement!(cb.affect!, u, t, integrator)
end

function initial_refinement!(cb::ParticleRefinementCallback, u, t, integrator)
cb(integrator)
end

# condition
function (refinement_callback::ParticleRefinementCallback)(u, t, integrator)
(; interval) = refinement_callback

return condition_integrator_interval(integrator, interval)
end

# affect
function (refinement_callback::ParticleRefinementCallback)(integrator)
t = integrator.t
semi = integrator.p
v_ode, u_ode = integrator.u.x

# Basically `get_tmp_cache(integrator)` to write into in order to be non-allocating
# https://docs.sciml.ai/DiffEqDocs/stable/basics/integrator/#Caches
v_tmp, u_tmp = integrator.cache.tmp.x

v_tmp .= v_ode
u_tmp .= u_ode

refinement!(semi, v_ode, u_ode, v_tmp, u_tmp, t)

resize!(integrator, (length(v_ode), length(u_ode)))

# Tell OrdinaryDiffEq that u has been modified
u_modified!(integrator, true)

return integrator
end

Base.resize!(a::RecursiveArrayTools.ArrayPartition, sizes::Tuple) = resize!.(a.x, sizes)

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:ParticleRefinementCallback})
@nospecialize cb # reduce precompilation time
print(io, "ParticleRefinementCallback(interval=", (cb.affect!.interval), ")")
end

function Base.show(io::IO,
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:ParticleRefinementCallback}})
@nospecialize cb # reduce precompilation time
print(io, "ParticleRefinementCallback(dt=", cb.affect!.affect!.interval, ")")
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any, <:ParticleRefinementCallback})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
refinement_cb = cb.affect!
setup = [
"interval" => refinement_cb.interval
]
summary_box(io, "ParticleRefinementCallback", setup)
end
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:ParticleRefinementCallback}})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
refinement_cb = cb.affect!.affect!
setup = [
"dt" => refinement_cb.interval
]
summary_box(io, "ParticleRefinementCallback", setup)
end
end