Skip to content

Commit

Permalink
Changed the Api to use get() for ptrs and operator[] for Ref (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaAtulTewari authored Apr 2, 2024
1 parent 577b6ed commit b599e06
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 118 deletions.
45 changes: 37 additions & 8 deletions include/pando-lib-galois/containers/host_indexed_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class HostIndexedMap {
using iterator = HostIndexedMapIt<T>;
using reverse_iterator = std::reverse_iterator<iterator>;

[[nodiscard]] constexpr std::uint64_t getNumHosts() const noexcept {
[[nodiscard]] static constexpr std::uint64_t getNumHosts() noexcept {
return static_cast<std::uint64_t>(pando::getPlaceDims().node.id);
}

[[nodiscard]] constexpr std::uint64_t getCurrentNode() const noexcept {
[[nodiscard]] constexpr std::uint64_t getCurrentHost() const noexcept {
return static_cast<std::uint64_t>(pando::getCurrentPlace().node.id);
}

std::uint64_t size() {
static constexpr std::uint64_t size() noexcept {
return getNumHosts();
}

Expand All @@ -57,20 +57,49 @@ class HostIndexedMap {
deallocateMemory(m_items, getNumHosts());
}

pando::GlobalRef<T> getLocal() noexcept {
return m_items[getCurrentNode()];
pando::GlobalPtr<T> get(std::uint64_t i) noexcept {
return &m_items[i];
}

pando::GlobalRef<T> get(std::uint64_t i) noexcept {
return m_items[i];
pando::GlobalPtr<const T> get(std::uint64_t i) const noexcept {
return &m_items[i];
}

pando::GlobalPtr<T> getLocal() noexcept {
return &m_items[getCurrentHost()];
}

pando::GlobalPtr<const T> getLocal() const noexcept {
return &m_items[getCurrentHost()];
}

pando::GlobalRef<T> getLocalRef() noexcept {
return m_items[getCurrentHost()];
}

pando::GlobalRef<const T> getLocalRef() const noexcept {
return m_items[getCurrentHost()];
}

pando::GlobalRef<T> operator[](std::uint64_t i) noexcept {
return *this->get(i);
}

pando::GlobalRef<const T> operator[](std::uint64_t i) const noexcept {
return *this->get(i);
}

template <typename Y>
pando::GlobalRef<T> getFromPtr(pando::GlobalPtr<Y> ptr) {
pando::GlobalPtr<T> getFromPtr(pando::GlobalPtr<Y> ptr) {
std::uint64_t i = static_cast<std::uint64_t>(pando::localityOf(ptr).node.id);
return this->get(i);
}

template <typename Y>
pando::GlobalRef<T> getRefFromPtr(pando::GlobalPtr<Y> ptr) {
return *getFromPtr(ptr);
}

iterator begin() noexcept {
return iterator(m_items, 0);
}
Expand Down
42 changes: 33 additions & 9 deletions include/pando-lib-galois/containers/host_local_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ class HostLocalStorage {
using iterator = HostLocalStorageIt<T>;
using reverse_iterator = std::reverse_iterator<iterator>;

[[nodiscard]] constexpr std::uint64_t getNumHosts() const noexcept {
[[nodiscard]] static constexpr std::uint64_t getNumHosts() noexcept {
return static_cast<std::uint64_t>(pando::getPlaceDims().node.id);
}

[[nodiscard]] constexpr std::uint64_t getCurrentNode() const noexcept {
[[nodiscard]] static constexpr std::uint64_t getCurrentHost() noexcept {
return static_cast<std::uint64_t>(pando::getCurrentPlace().node.id);
}

std::uint64_t size() {
static constexpr std::uint64_t size() noexcept {
return getNumHosts();
}

Expand All @@ -90,12 +90,36 @@ class HostLocalStorage {
HostLocalStorageHeap::deallocate(m_items);
}

pando::GlobalRef<T> getLocal() noexcept {
pando::GlobalPtr<T> getLocal() noexcept {
return m_items.getPointer();
}

pando::GlobalPtr<const T> getLocal() const noexcept {
return m_items.getPointer();
}

pando::GlobalRef<T> getLocalRef() noexcept {
return *m_items.getPointer();
}

pando::GlobalRef<T> get(std::uint64_t i) noexcept {
return *m_items.getPointerAt(pando::NodeIndex(static_cast<std::int16_t>(i)));
pando::GlobalRef<const T> getLocalRef() const noexcept {
return *m_items.getPointer();
}

pando::GlobalPtr<T> get(std::uint64_t i) noexcept {
return m_items.getPointerAt(pando::NodeIndex(static_cast<std::int16_t>(i)));
}

pando::GlobalPtr<const T> get(std::uint64_t i) const noexcept {
return m_items.getPointerAt(pando::NodeIndex(static_cast<std::int16_t>(i)));
}

pando::GlobalRef<T> operator[](std::uint64_t i) noexcept {
return *this->get(i);
}

pando::GlobalRef<const T> operator[](std::uint64_t i) const noexcept {
return *this->get(i);
}

template <typename Y>
Expand Down Expand Up @@ -180,15 +204,15 @@ class HostLocalStorageIt {
constexpr HostLocalStorageIt& operator=(HostLocalStorageIt&&) noexcept = default;

reference operator*() const noexcept {
return m_curr.get(m_loc);
return m_curr[m_loc];
}

reference operator*() noexcept {
return m_curr.get(m_loc);
return m_curr[m_loc];
}

pointer operator->() {
return &m_curr.get(m_loc);
return m_curr.get(m_loc);
}

HostLocalStorageIt& operator++() {
Expand Down
13 changes: 7 additions & 6 deletions include/pando-lib-galois/containers/per_thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class PerThreadVector {
// Initialize the per host vectors
for (std::int16_t i = 0; i < static_cast<std::int16_t>(lift(flat, getNumHosts)); i++) {
auto place = pando::Place{pando::NodeIndex{i}, pando::anyPod, pando::anyCore};
auto ref = fmap(flat, get, i);
auto ref = fmap(flat, operator[], i);
std::uint64_t start =
(i == 0) ? 0 : m_indices[static_cast<std::uint64_t>(i) * cores * threads - 1];
std::uint64_t end = m_indices[static_cast<std::uint64_t>(i + 1) * cores * threads - 1];
Expand All @@ -322,10 +322,10 @@ class PerThreadVector {
std::uint64_t start = (host == 0) ? 0 : assign.data.m_indices[index];
std::uint64_t curr = (i == 0) ? 0 : assign.data.m_indices[i - 1];

auto ref = assign.to.get(host);
auto ref = assign.to[host];
pando::Vector<T> localVec = assign.data[i];
for (T elt : localVec) {
fmap(ref, get, curr - start) = elt;
fmap(ref, operator[], curr - start) = elt;
curr++;
}
};
Expand All @@ -343,7 +343,7 @@ class PerThreadVector {
// TODO(AdityaAtulTewari) Make this properly parallel.
// Initialize the per host vectors
for (std::int16_t i = 0; i < static_cast<std::int16_t>(flat.getNumHosts()); i++) {
auto ref = flat.get(i);
auto ref = flat[i];
std::uint64_t start =
(i == 0) ? 0 : m_indices[static_cast<std::uint64_t>(i) * cores * threads - 1];
std::uint64_t end = m_indices[static_cast<std::uint64_t>(i + 1) * cores * threads - 1];
Expand All @@ -364,7 +364,7 @@ class PerThreadVector {
std::uint64_t end =
assign.data.m_indices[(host + 1) * assign.data.cores * assign.data.threads - 1];

auto ref = assign.to.get(host);
auto ref = assign.to[host];
pando::Vector<T> localVec = assign.data[i];
std::uint64_t size = lift(ref, size) - (end - start);
for (T elt : localVec) {
Expand Down Expand Up @@ -459,6 +459,7 @@ class PTVectorIterator {
using value_type = pando::Vector<T>;
using pointer = pando::GlobalPtr<pando::Vector<T>>;
using reference = pando::GlobalRef<pando::Vector<T>>;
using const_reference = pando::GlobalRef<const pando::Vector<T>>;

PTVectorIterator(PerThreadVector<T> arr, std::uint64_t pos) : m_arr(arr), m_pos(pos) {}

Expand Down Expand Up @@ -520,7 +521,7 @@ class PTVectorIterator {
return m_arr.get(m_pos + n);
}

reference operator[](std::uint64_t n) const noexcept {
const_reference operator[](std::uint64_t n) const noexcept {
return m_arr.get(m_pos + n);
}

Expand Down
33 changes: 24 additions & 9 deletions include/pando-lib-galois/containers/pod_local_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,19 @@ class PodLocalStorage {
using iterator = PodLocalStorageIt<T>;
using reverse_iterator = std::reverse_iterator<iterator>;

[[nodiscard]] constexpr std::uint64_t getNumPods() const noexcept {
[[nodiscard]] static constexpr std::uint64_t getNumPods() noexcept {
const auto p = pando::getPlaceDims();
return static_cast<std::uint64_t>(p.node.id * p.pod.x * p.pod.y);
}

[[nodiscard]] constexpr std::uint64_t getCurrentPodIdx() const noexcept {
[[nodiscard]] static constexpr std::uint64_t getCurrentPodIdx() noexcept {
const auto dim = pando::getPlaceDims();
const auto cur = pando::getCurrentPlace();
return static_cast<std::uint64_t>(cur.node.id * dim.pod.x * dim.pod.y + cur.pod.x * dim.pod.y +
cur.pod.y);
}

[[nodiscard]] constexpr pando::Place getPlaceFromPodIdx(std::uint64_t idx) const noexcept {
[[nodiscard]] static constexpr pando::Place getPlaceFromPodIdx(std::uint64_t idx) noexcept {
const auto dim = pando::getPlaceDims();
const auto pods = dim.pod.x * dim.pod.y;
const pando::NodeIndex node = pando::NodeIndex(idx / pods);
Expand All @@ -93,7 +93,7 @@ class PodLocalStorage {
return pando::Place(node, pod, pando::anyCore);
}

std::uint64_t size() {
static constexpr std::uint64_t size() noexcept {
return getNumPods();
}

Expand All @@ -110,7 +110,22 @@ class PodLocalStorage {
return *m_items.getPointer();
}

pando::GlobalRef<T> get(std::uint64_t i) noexcept {
pando::GlobalPtr<T> get(std::uint64_t i) noexcept {
auto place = getPlaceFromPodIdx(i);
return *m_items.getPointerAt(place.node, place.pod);
}

pando::GlobalPtr<const T> get(std::uint64_t i) const noexcept {
auto place = getPlaceFromPodIdx(i);
return m_items.getPointerAt(place.node, place.pod);
}

pando::GlobalRef<T> operator[](std::uint64_t i) noexcept {
auto place = getPlaceFromPodIdx(i);
return *m_items.getPointerAt(place.node, place.pod);
}

pando::GlobalRef<const T> operator[](std::uint64_t i) const noexcept {
auto place = getPlaceFromPodIdx(i);
return *m_items.getPointerAt(place.node, place.pod);
}
Expand Down Expand Up @@ -199,15 +214,15 @@ class PodLocalStorageIt {
constexpr PodLocalStorageIt& operator=(PodLocalStorageIt&&) noexcept = default;

reference operator*() const noexcept {
return m_curr.get(m_loc);
return m_curr[m_loc];
}

reference operator*() noexcept {
return m_curr.get(m_loc);
return m_curr[m_loc];
}

pointer operator->() {
return &m_curr.get(m_loc);
return m_curr.get(m_loc);
}

PodLocalStorageIt& operator++() {
Expand Down Expand Up @@ -288,7 +303,7 @@ template <typename T>
const std::uint64_t size = cont.size();
PANDO_CHECK(copy.initialize(size));
for (std::uint64_t i = 0; i < cont.size(); i++) {
copy.get(i) = cont.get(i);
copy[i] = cont[i];
}
refcopy = copy;
}));
Expand Down
Loading

0 comments on commit b599e06

Please sign in to comment.