Skip to content

Commit

Permalink
Extend data access API (#169)
Browse files Browse the repository at this point in the history
* add getter for number of dofs

* change trixi_load_cell_average

add parameter index and only get averaged of the variable at position index

* fix: ensure Int32

* adapt reference value

* add getter for all dofs values

* adapt next value

* add trixi_ndofs_element

* add tests

* add missing parts in tests

* fix tests

* deallocate first

* add trixi_load_prim to Fortran API

* reference value

* get doxygen right

* make everything more consistent!

* add functions to get quadrature information

* update CI badge URL

* add get_t8code_forest to Fortran interface

* format

* remove duplicate

* make gcc 14 happy

* Apply suggestions from code review

Co-authored-by: Michael Schlottke-Lakemper <[email protected]>

---------

Co-authored-by: Michael Schlottke-Lakemper <[email protected]>
Co-authored-by: Benedict <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2024
1 parent 3a813d2 commit 0e328c1
Show file tree
Hide file tree
Showing 14 changed files with 763 additions and 156 deletions.
39 changes: 30 additions & 9 deletions LibTrixi.jl/src/LibTrixi.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module LibTrixi

using OrdinaryDiffEq: OrdinaryDiffEq, step!, check_error, DiscreteCallback
using Trixi: Trixi, summary_callback, mesh_equations_solver_cache, nelements,
nelementsglobal, nvariables, nnodes, wrap_array, eachelement, cons2prim,
get_node_vars, eachnode
using Trixi: Trixi, summary_callback, mesh_equations_solver_cache, ndims, nelements,
nelementsglobal, ndofs, ndofsglobal, nvariables, nnodes, wrap_array,
eachelement, cons2prim, get_node_vars, eachnode
using MPI: MPI, run_init_hooks, set_default_error_handler_return
using Pkg

Expand All @@ -28,15 +28,36 @@ export trixi_ndims,
export trixi_nelements,
trixi_nelements_cfptr,
trixi_nelements_jl
export trixi_nelements_global,
trixi_nelements_global_cfptr,
trixi_nelements_global_jl
export trixi_nelementsglobal,
trixi_nelementsglobal_cfptr,
trixi_nelementsglobal_jl
export trixi_ndofs,
trixi_ndofs_cfptr,
trixi_ndofs_jl
export trixi_ndofsglobal,
trixi_ndofsglobal_cfptr,
trixi_ndofsglobal_jl
export trixi_ndofselement,
trixi_ndofselement_cfptr,
trixi_ndofselement_jl
export trixi_nvariables,
trixi_nvariables_cfptr,
trixi_nvariables_jl
export trixi_load_cell_averages,
trixi_load_cell_averages_cfptr,
trixi_load_cell_averages_jl
export trixi_nnodes,
trixi_nnodes_cfptr,
trixi_nnodes_jl
export trixi_load_node_reference_coordinates,
trixi_load_node_reference_coordinates_cfptr,
trixi_load_node_reference_coordinates_jl
export trixi_load_node_weights,
trixi_load_node_weights_cfptr,
trixi_load_node_weights_jl
export trixi_load_primitive_vars,
trixi_load_primitive_vars_cfptr,
trixi_load_primitive_vars_jl
export trixi_load_element_averaged_primitive_vars,
trixi_load_element_averaged_primitive_vars_cfptr,
trixi_load_element_averaged_primitive_vars_jl
export trixi_version_library,
trixi_version_library_cfptr,
trixi_version_library_jl
Expand Down
175 changes: 155 additions & 20 deletions LibTrixi.jl/src/api_c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ trixi_ndims_cfptr() = @cfunction(trixi_ndims, Cint, (Cint,))
"""
trixi_nelements(simstate_handle::Cint)::Cint
Return number of local elements (cells).
Return number of elements local to the MPI rank.
"""
function trixi_nelements end

Expand All @@ -298,18 +298,63 @@ trixi_nelements_cfptr() = @cfunction(trixi_nelements, Cint, (Cint,))


"""
trixi_nelements_global(simstate_handle::Cint)::Cint
trixi_nelementsglobal(simstate_handle::Cint)::Cint
Return number of global elements (cells).
Return global number of elements on all MPI ranks.
"""
function trixi_nelements_global end
function trixi_nelementsglobal end

Base.@ccallable function trixi_nelements_global(simstate_handle::Cint)::Cint
Base.@ccallable function trixi_nelementsglobal(simstate_handle::Cint)::Cint
simstate = load_simstate(simstate_handle)
return trixi_nelements_global_jl(simstate)
return trixi_nelementsglobal_jl(simstate)
end

trixi_nelements_global_cfptr() = @cfunction(trixi_nelements_global, Cint, (Cint,))
trixi_nelementsglobal_cfptr() = @cfunction(trixi_nelementsglobal, Cint, (Cint,))


"""
trixi_ndofs(simstate_handle::Cint)::Cint
Return number of degrees of freedom (all quadrature points on all elements of current rank).
"""
function trixi_ndofs end

Base.@ccallable function trixi_ndofs(simstate_handle::Cint)::Cint
simstate = load_simstate(simstate_handle)
return trixi_ndofs_jl(simstate)
end

trixi_ndofs_cfptr() = @cfunction(trixi_ndofs, Cint, (Cint,))


"""
trixi_ndofsglobal(simstate_handle::Cint)::Cint
Return global number of degrees of freedom (all quadrature points on all elements on all ranks).
"""
function trixi_ndofsglobal end

Base.@ccallable function trixi_ndofsglobal(simstate_handle::Cint)::Cint
simstate = load_simstate(simstate_handle)
return trixi_ndofsglobal_jl(simstate)
end

trixi_ndofsglobal_cfptr() = @cfunction(trixi_ndofsglobal, Cint, (Cint,))


"""
trixi_ndofselement(simstate_handle::Cint)::Cint
Return number of degrees of freedom per element.
"""
function trixi_ndofselement end

Base.@ccallable function trixi_ndofselement(simstate_handle::Cint)::Cint
simstate = load_simstate(simstate_handle)
return trixi_ndofselement_jl(simstate)
end

trixi_ndofselement_cfptr() = @cfunction(trixi_ndofselement, Cint, (Cint,))


"""
Expand All @@ -328,32 +373,122 @@ trixi_nvariables_cfptr() = @cfunction(trixi_nvariables, Cint, (Cint,))


"""
trixi_load_cell_averages(data::Ptr{Cdouble}, simstate_handle::Cint)::Cvoid
trixi_nnodes(simstate_handle::Cint)::Cint
Return number of quadrature nodes per dimension.
"""
function trixi_nnodes end

Base.@ccallable function trixi_nnodes(simstate_handle::Cint)::Cint
simstate = load_simstate(simstate_handle)
return trixi_nnodes_jl(simstate)
end

trixi_nnodes_cfptr() = @cfunction(trixi_nnodes, Cint, (Cint,))


"""
trixi_load_node_reference_coordinates(simstate_handle::Cint, data::Ptr{Cdouble})::Cvoid
Get reference coordinates of 1D quadrature nodes.
"""
function trixi_load_node_reference_coordinates end

Base.@ccallable function trixi_load_node_reference_coordinates(simstate_handle::Cint,
data::Ptr{Cdouble})::Cvoid
simstate = load_simstate(simstate_handle)

# convert C to Julia array
size = trixi_nnodes_jl(simstate)
data_jl = unsafe_wrap(Array, data, size)

trixi_load_node_reference_coordinates_jl(simstate, data_jl)
return nothing
end

trixi_load_node_reference_coordinates_cfptr() =
@cfunction(trixi_load_node_reference_coordinates, Cvoid, (Cint, Ptr{Cdouble}))


"""
trixi_load_node_weights(simstate_handle::Cint, data::Ptr{Cdouble})::Cvoid
Get weights of 1D quadrature nodes.
"""
function trixi_load_node_weights end

Base.@ccallable function trixi_load_node_weights(simstate_handle::Cint,
data::Ptr{Cdouble})::Cvoid
simstate = load_simstate(simstate_handle)

# convert C to Julia array
size = trixi_nnodes_jl(simstate)
data_jl = unsafe_wrap(Array, data, size)

return trixi_load_node_weights_jl(simstate, data_jl)
end

trixi_load_node_weights_cfptr() =
@cfunction(trixi_load_node_weights, Cvoid, (Cint, Ptr{Cdouble}))


"""
trixi_load_primitive_vars(simstate_handle::Cint, variable_id::Cint,
data::Ptr{Cdouble})::Cvoid
Load primitive variable.
The values for the primitive variable at position `variable_id` at every degree of freedom
are stored in the given array `data`.
The given array has to be of correct size (ndofs) and memory has to be allocated beforehand.
"""
function trixi_load_primitive_vars end

Base.@ccallable function trixi_load_primitive_vars(simstate_handle::Cint, variable_id::Cint,
data::Ptr{Cdouble})::Cvoid
simstate = load_simstate(simstate_handle)

# convert C to Julia array
size = trixi_ndofs_jl(simstate)
data_jl = unsafe_wrap(Array, data, size)

trixi_load_primitive_vars_jl(simstate, variable_id, data_jl)
return nothing
end

trixi_load_primitive_vars_cfptr() =
@cfunction(trixi_load_primitive_vars, Cvoid, (Cint, Cint, Ptr{Cdouble}))


"""
trixi_load_element_averaged_primitive_vars(simstate_handle::Cint, variable_id::Cint,
data::Ptr{Cdouble})::Cvoid
Return cell averaged solution state.
Load element averages for primitive variable.
Cell averaged values for each cell and each primitive variable are stored in a contiguous
array, where cell values for the first variable appear first and values for the other
variables subsequently (structure-of-arrays layout).
Element averaged values for the primitive variable at position `variable_id` for each
element are stored in the given array `data`.
The given array has to be of correct size and memory has to be allocated beforehand.
The given array has to be of correct size (nelements) and memory has to be allocated
beforehand.
"""
function trixi_load_cell_averages end
function trixi_load_element_averaged_primitive_vars end

Base.@ccallable function trixi_load_cell_averages(data::Ptr{Cdouble},
simstate_handle::Cint)::Cvoid
Base.@ccallable function trixi_load_element_averaged_primitive_vars(simstate_handle::Cint,
variable_id::Cint, data::Ptr{Cdouble})::Cvoid
simstate = load_simstate(simstate_handle)

# convert C to Julia array
size = trixi_nvariables_jl(simstate) * trixi_nelements_jl(simstate)
size = trixi_nelements_jl(simstate)
data_jl = unsafe_wrap(Array, data, size)

trixi_load_cell_averages_jl(data_jl, simstate)
trixi_load_element_averaged_primitive_vars_jl(simstate, variable_id, data_jl)
return nothing
end

trixi_load_cell_averages_cfptr() =
@cfunction(trixi_load_cell_averages, Cvoid, (Ptr{Cdouble}, Cint,))
trixi_load_element_averaged_primitive_vars_cfptr() =
@cfunction(trixi_load_element_averaged_primitive_vars, Cvoid, (Cint, Cint, Ptr{Cdouble}))


"""
Expand Down
86 changes: 72 additions & 14 deletions LibTrixi.jl/src/api_jl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,85 @@ function trixi_nelements_jl(simstate)
end


function trixi_nelements_global_jl(simstate)
function trixi_nelementsglobal_jl(simstate)
mesh, _, solver, cache = mesh_equations_solver_cache(simstate.semi)
return nelementsglobal(mesh, solver, cache)
end


function trixi_ndofs_jl(simstate)
mesh, _, solver, cache = mesh_equations_solver_cache(simstate.semi)
return ndofs(mesh, solver, cache)
end


function trixi_ndofsglobal_jl(simstate)
mesh, _, solver, cache = mesh_equations_solver_cache(simstate.semi)
return ndofsglobal(mesh, solver, cache)
end


function trixi_ndofselement_jl(simstate)
mesh, _, solver, _ = mesh_equations_solver_cache(simstate.semi)
return nnodes(solver)^ndims(mesh)
end


function trixi_nvariables_jl(simstate)
_, equations, _, _ = mesh_equations_solver_cache(simstate.semi)
return nvariables(equations)
end


function trixi_load_cell_averages_jl(data, simstate)
function trixi_nnodes_jl(simstate)
_, _, solver, _ = mesh_equations_solver_cache(simstate.semi)
return nnodes(solver)
end


function trixi_load_node_reference_coordinates_jl(simstate, data)
_, _, solver, _ = mesh_equations_solver_cache(simstate.semi)
for i in eachnode(solver)
data[i] = solver.basis.nodes[i]
end
end


function trixi_load_node_weights_jl(simstate, data)
_, _, solver, _ = mesh_equations_solver_cache(simstate.semi)
for i in eachnode(solver)
data[i] = solver.basis.weights[i]
end
end


function trixi_load_primitive_vars_jl(simstate, variable_id, data)
mesh, equations, solver, cache = mesh_equations_solver_cache(simstate.semi)
n_nodes_per_dim = nnodes(solver)
n_dims = ndims(mesh)
n_nodes = n_nodes_per_dim^n_dims

u_ode = simstate.integrator.u
u = wrap_array(u_ode, mesh, equations, solver, cache)

# all permutations of nodes indices for arbitrary dimension
node_cis = CartesianIndices(ntuple(i -> n_nodes_per_dim, n_dims))
node_lis = LinearIndices(node_cis)

for element in eachelement(solver, cache)
for node_ci in node_cis
node_vars = get_node_vars(u, equations, solver, node_ci, element)
node_index = (element-1) * n_nodes + node_lis[node_ci]
data[node_index] = cons2prim(node_vars, equations)[variable_id]
end
end

return nothing
end


function trixi_load_element_averaged_primitive_vars_jl(simstate, variable_id, data)
mesh, equations, solver, cache = mesh_equations_solver_cache(simstate.semi)
n_elements = nelements(solver, cache)
n_variables = nvariables(equations)
n_nodes = nnodes(solver)
n_dims = ndims(mesh)

Expand All @@ -111,15 +174,13 @@ function trixi_load_cell_averages_jl(data, simstate)
# all permutations of nodes indices for arbitrary dimension
node_cis = CartesianIndices(ntuple(i -> n_nodes, n_dims))

# temporary storage for mean value on current element for all variables
u_mean = get_node_vars(u, equations, solver, node_cis[1], 1)

for element in eachelement(solver, cache)

# compute mean value using nodal dg values and quadrature
u_mean = zero(u_mean)
u_mean = zero(eltype(u))
for node_ci in node_cis
u_node_prim = cons2prim(get_node_vars(u, equations, solver, node_ci, element), equations)
u_node_prim = cons2prim(get_node_vars(u, equations, solver, node_ci, element),
equations)[variable_id]
weight = 1.
for node_index in Tuple(node_ci)
weight *= solver.basis.weights[node_index]
Expand All @@ -130,11 +191,8 @@ function trixi_load_cell_averages_jl(data, simstate)
# normalize to unit element
u_mean = u_mean / 2^n_dims

# copy to provided array
# all element values for first variable, then for second, ...
for ivar = 0:n_variables-1
data[element + ivar * n_elements] = u_mean[ivar+1]
end
# write to provided array
data[element] = u_mean
end

return nothing
Expand Down
Loading

0 comments on commit 0e328c1

Please sign in to comment.