Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into new-rocm-container
Browse files Browse the repository at this point in the history
  • Loading branch information
fthaler committed Oct 29, 2024
2 parents 2f61ca3 + 32daaa5 commit da6aab0
Show file tree
Hide file tree
Showing 15 changed files with 87,177 additions and 86,980 deletions.
19 changes: 19 additions & 0 deletions include/gridtools/fn/backend/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,29 @@ namespace gridtools::fn::backend {
meta::rename<tuple, Dims>());
}

/**
 * Composes one unrolled loop per dimension listed in `Dims` into a single nested loop functor.
 *
 * The dimensions are folded starting from the identity functor. For every dimension an unrolled
 * loop is created from the size stored under that dimension's key in `sizes`; its unroll factor is
 * the `value` of the `UnrollFactors` element associated with the dimension. Each fold step yields
 * a callable that applies the new loop to its arguments and hands the result to the loops
 * accumulated so far.
 *
 * The second (unnamed) parameter is only used to deduce `UnrollFactors`.
 */
template <class Dims, class Sizes, class UnrollFactors>
constexpr GT_FUNCTION auto make_unrolled_loops(Sizes const &sizes, UnrollFactors) {
    return tuple_util::host_device::fold(
        [&](auto enclosing, auto dim) {
            using dim_t = decltype(dim);
            using factor_t = element_at<dim_t, UnrollFactors>;
            return [enclosing = std::move(enclosing),
                       current = sid::make_unrolled_loop<dim_t, factor_t::value>(
                           host_device::at_key<dim_t>(sizes))](auto &&...xs) {
                return enclosing(current(std::forward<decltype(xs)>(xs)...));
            };
        },
        host_device::identity(),
        meta::rename<tuple, Dims>());
}

/**
 * Convenience overload: builds the nested loops using `get_keys<Sizes>` as the list of dimensions.
 */
template <class Sizes>
constexpr GT_FUNCTION auto make_loops(Sizes const &sizes) {
    using dims_t = get_keys<Sizes>;
    return make_loops<dims_t>(sizes);
}

/**
 * Convenience overload: builds the nested unrolled loops using `get_keys<Sizes>` as the list of
 * dimensions; `unroll_factors` is passed on unchanged.
 */
template <class Sizes, class UnrollFactors>
constexpr GT_FUNCTION auto make_unrolled_loops(Sizes const &sizes, UnrollFactors unroll_factors) {
    using dims_t = get_keys<Sizes>;
    return make_unrolled_loops<dims_t>(sizes, unroll_factors);
}
} // namespace common

template <class T>
Expand Down
223 changes: 147 additions & 76 deletions include/gridtools/fn/backend/gpu.hpp

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions include/gridtools/fn/common_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
*/
#pragma once

#include <type_traits>

#include "../common/tuple.hpp"
#include "../common/tuple_util.hpp"
#include "../stencil/positional.hpp"
#include "./backend/common.hpp"

namespace gridtools::fn {
Expand All @@ -30,4 +33,9 @@ namespace gridtools::fn {
return {std::forward<Args>(args)...};
}

template <class D>
constexpr auto index(D) {
return gridtools::stencil::positional<D>();
}

} // namespace gridtools::fn
8 changes: 3 additions & 5 deletions include/gridtools/fn/unstructured.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "../common/ldg_ptr.hpp"
#include "../meta/logical.hpp"
#include "../sid/concept.hpp"
#include "../stencil/positional.hpp"
#include "./common_interface.hpp"
#include "./executor.hpp"
#include "./neighbor_table.hpp"
Expand All @@ -28,7 +27,6 @@ namespace gridtools::fn {
} // namespace unstructured::dim

namespace unstructured_impl_ {
using gridtools::stencil::positional;
namespace dim = unstructured::dim;

template <class Tables, class Sizes>
Expand Down Expand Up @@ -145,13 +143,13 @@ namespace gridtools::fn {
Domain m_domain;
TmpAllocator m_allocator;

static constexpr auto index = positional<dim::horizontal>();
static constexpr auto horizontal_index = index(dim::horizontal{});

auto stencil_executor() const {
return [&] {
return make_stencil_executor<1>(
m_backend, m_domain.m_sizes, m_domain.m_offsets, make_iterator(m_domain.without_offsets()))
.arg(index); // the horizontal index is passed as the first argument
.arg(horizontal_index); // the horizontal index is passed as the first argument
};
}

Expand All @@ -160,7 +158,7 @@ namespace gridtools::fn {
return [&] {
return make_vertical_executor<Vertical, 1>(
m_backend, m_domain.m_sizes, m_domain.m_offsets, make_iterator(m_domain.without_offsets()))
.arg(index); // the horizontal index is passed as the first argument
.arg(horizontal_index); // the horizontal index is passed as the first argument
};
}
};
Expand Down
39 changes: 39 additions & 0 deletions include/gridtools/sid/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <utility>

#include "../common/defs.hpp"
#include "../common/for_each.hpp"
#include "../common/functional.hpp"
#include "../common/host_device.hpp"
#include "../common/integral_constant.hpp"
Expand Down Expand Up @@ -637,6 +638,44 @@ namespace gridtools {
return {};
}

/**
 * Creates a loop over dimension `Key` whose body is unrolled `UnrollFactor` times
 * (overload selected for `UnrollFactor > 1`).
 *
 * The iteration space is split into two loops, both created with `make_loop<Key>`:
 *  - an unrolled loop with `num_steps / UnrollFactor` iterations and step `step * UnrollFactor`;
 *    its body invokes the user functor `UnrollFactor` times, shifting the pointer by `step`
 *    after each invocation;
 *  - an epilogue loop with the remaining `num_steps % UnrollFactor` iterations and step `step`,
 *    started at offset `step * (num_steps / UnrollFactor) * UnrollFactor`.
 *
 * Every shift issued directly by this function is undone again before the respective callable
 * returns.
 *
 * @param num_steps number of iterations of the (non-unrolled) loop
 * @param step      stride multiplier applied per iteration (defaults to the constant 1)
 * @return a callable that takes the loop body `fun(ptr, strides)` and returns a callable
 *         `(ptr, strides)` executing the unrolled loop followed by the epilogue
 */
template <class Key,
    int UnrollFactor,
    class NumSteps,
    class Step = integral_constant<int, 1>,
    std::enable_if_t<(UnrollFactor > 1), int> = 0>
constexpr GT_FUNCTION auto make_unrolled_loop(NumSteps num_steps, Step step = {}) {
    using u = integral_constant<int, UnrollFactor>;
    return [step,
               unrolled = make_loop<Key>(num_steps / u(), step * u()),
               epilogue = make_loop<Key>(num_steps % u(), step),
               epilogue_start = step * ((num_steps / u()) * u())](auto &&fun) {
        // `fun` is needed by both the unrolled body and the epilogue. It is copied here, in a
        // separate statement, so that the copy is sequenced before the single forward below;
        // forwarding it twice would hand a moved-from functor to one of the two loops whenever
        // `fun` is passed as an rvalue.
        auto unrolled_body = [step, fun = fun](auto &&ptr, auto const strides) {
            ::gridtools::host_device::for_each<meta::make_indices_c<UnrollFactor>>([&](auto) {
                fun(std::forward<decltype(ptr)>(ptr), strides);
                shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), step);
            });
            // Undo the UnrollFactor single-step shifts above; moving on to the next group of
            // iterations is presumably done by the enclosing `make_loop` (step * UnrollFactor).
            shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), -step * u());
        };
        return [unrolled = unrolled(std::move(unrolled_body)),
                   epilogue = epilogue(std::forward<decltype(fun)>(fun)), // last use of `fun`
                   epilogue_start](auto &&ptr, auto const &strides) {
            unrolled(std::forward<decltype(ptr)>(ptr), strides);
            // Run the remainder iterations starting right after the unrolled part, then restore
            // the pointer.
            shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), epilogue_start);
            epilogue(std::forward<decltype(ptr)>(ptr), strides);
            shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), -epilogue_start);
        };
    };
}

/**
 * Overload selected for `UnrollFactor == 1`: no unrolling is performed, the call is passed
 * straight on to `make_loop<Key>`.
 */
template <class Key,
    int UnrollFactor,
    class NumSteps,
    class Step = integral_constant<int, 1>,
    std::enable_if_t<(UnrollFactor == 1), int> = 0>
constexpr GT_FUNCTION auto make_unrolled_loop(NumSteps num_steps, Step step = {}) {
    // An unroll factor of one is just the plain loop.
    return make_loop<Key>(num_steps, step);
}

/**
* A helper that allows using `SID`s with a C++11 range-based for loop
*
Expand Down
Loading

0 comments on commit da6aab0

Please sign in to comment.