Skip to content

Commit

Permalink
Do not initialize the pinned mdarray at construction time
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Nov 1, 2024
1 parent 41ddbfe commit ad95826
Showing 1 changed file with 24 additions and 43 deletions.
67 changes: 24 additions & 43 deletions cpp/include/raft/core/pinned_container_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

#include <cstddef>
#ifndef RAFT_DISABLE_CUDA
#include <thrust/host_vector.h>
#include <thrust/mr/allocator.h>
#include <thrust/system/cuda/memory_resource.h>
#include <cuda_runtime.h>

#include <type_traits>
#else
#include <raft/core/detail/fail_container_policy.hpp>
#endif
Expand All @@ -30,20 +30,16 @@ namespace raft {
#ifndef RAFT_DISABLE_CUDA

/**
* @brief A thin wrapper over thrust::host_vector for implementing the pinned mdarray container
* policy.
* @brief A thin wrapper over cudaMallocHost/cudaFreeHost for implementing the pinned mdarray
* container policy.
*
*/
template <typename T>
struct pinned_container {
using value_type = T;
using allocator_type =
thrust::mr::stateless_resource_allocator<value_type,
thrust::cuda::universal_host_pinned_memory_resource>;
using value_type = std::remove_cv_t<T>;

private:
using underlying_container_type = thrust::host_vector<value_type, allocator_type>;
underlying_container_type data_;
value_type* data_ = nullptr;

public:
using size_type = std::size_t;
Expand All @@ -57,21 +53,20 @@ struct pinned_container {
using iterator = pointer;
using const_iterator = const_pointer;

~pinned_container() = default;
pinned_container(pinned_container&&) noexcept = default;
pinned_container(pinned_container const& that) : data_{that.data_} {}

auto operator=(pinned_container<T> const& that) -> pinned_container<T>&
explicit pinned_container(std::size_t size)
{
RAFT_CUDA_TRY(cudaMallocHost(&data_, size * sizeof(value_type)));
}
~pinned_container() noexcept
{
data_ = underlying_container_type{that.data_};
return *this;
if (data_ != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaFreeHost(data_)); }
}
auto operator=(pinned_container<T>&& that) noexcept -> pinned_container<T>& = default;

/**
* @brief Ctor that accepts a size.
*/
explicit pinned_container(std::size_t size, allocator_type const& alloc) : data_{size, alloc} {}
pinned_container(pinned_container&&) = default;
pinned_container& operator=(pinned_container&&) = default;
pinned_container(pinned_container const&) = delete; // Copying disallowed: one array one owner
pinned_container& operator=(pinned_container const&) = delete;

/**
* @brief Index operator that returns a reference to the actual data.
*/
Expand All @@ -84,15 +79,13 @@ struct pinned_container {
* @brief Index operator that returns a reference to the actual data.
*/
template <typename Index>
auto operator[](Index i) const noexcept
auto operator[](Index i) const noexcept -> const_reference
{
return data_[i];
}

void resize(size_type size) { data_.resize(size, data_.stream()); }

[[nodiscard]] auto data() noexcept -> pointer { return data_.data().get(); }
[[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data().get(); }
[[nodiscard]] auto data() noexcept -> pointer { return data_; }
[[nodiscard]] auto data() const noexcept -> const_pointer { return data_; }
};

/**
Expand All @@ -102,39 +95,27 @@ template <typename ElementType>
struct pinned_vector_policy {
using element_type = ElementType;
using container_type = pinned_container<element_type>;
using allocator_type = typename container_type::allocator_type;
using pointer = typename container_type::pointer;
using const_pointer = typename container_type::const_pointer;
using reference = typename container_type::reference;
using const_reference = typename container_type::const_reference;
using accessor_policy = std::experimental::default_accessor<element_type>;
using const_accessor_policy = std::experimental::default_accessor<element_type const>;

auto create(raft::resources const&, size_t n) -> container_type
{
return container_type(n, allocator_);
}

constexpr pinned_vector_policy() noexcept(std::is_nothrow_default_constructible_v<ElementType>)
: allocator_{}
{
}
auto create(raft::resources const&, size_t n) -> container_type { return container_type(n); }

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
return c[n];
}
[[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept
-> const_reference
[[nodiscard]] constexpr auto access(container_type const& c,
size_t n) const noexcept -> const_reference
{
return c[n];
}

[[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; }
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
allocator_type allocator_;
};
#else
template <typename ElementType>
Expand Down

0 comments on commit ad95826

Please sign in to comment.