Skip to content

Commit

Permalink
mpi: support reductions of uint64_t (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz authored Feb 8, 2024
1 parent 693d0df commit cb9cf15
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/mpi/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,14 @@ void MpiWorld::op_reduce(faabric_op_t* operation,
outBufferCast[slot] =
std::max<int>(outBufferCast[slot], inBufferCast[slot]);
}
} else if (datatype->id == FAABRIC_UINT64) {
auto inBufferCast = reinterpret_cast<uint64_t*>(inBuffer);
auto outBufferCast = reinterpret_cast<uint64_t*>(outBuffer);

for (int slot = 0; slot < count; slot++) {
outBufferCast[slot] =
std::max<uint64_t>(outBufferCast[slot], inBufferCast[slot]);
}
} else if (datatype->id == FAABRIC_DOUBLE) {
auto inBufferCast = reinterpret_cast<double*>(inBuffer);
auto outBufferCast = reinterpret_cast<double*>(outBuffer);
Expand Down Expand Up @@ -1220,6 +1228,14 @@ void MpiWorld::op_reduce(faabric_op_t* operation,
outBufferCast[slot] =
std::min<int>(outBufferCast[slot], inBufferCast[slot]);
}
} else if (datatype->id == FAABRIC_UINT64) {
auto inBufferCast = reinterpret_cast<uint64_t*>(inBuffer);
auto outBufferCast = reinterpret_cast<uint64_t*>(outBuffer);

for (int slot = 0; slot < count; slot++) {
outBufferCast[slot] =
std::min<uint64_t>(outBufferCast[slot], inBufferCast[slot]);
}
} else if (datatype->id == FAABRIC_DOUBLE) {
auto inBufferCast = reinterpret_cast<double*>(inBuffer);
auto outBufferCast = reinterpret_cast<double*>(outBuffer);
Expand All @@ -1246,6 +1262,13 @@ void MpiWorld::op_reduce(faabric_op_t* operation,
auto inBufferCast = reinterpret_cast<int*>(inBuffer);
auto outBufferCast = reinterpret_cast<int*>(outBuffer);

for (int slot = 0; slot < count; slot++) {
outBufferCast[slot] += inBufferCast[slot];
}
} else if (datatype->id == FAABRIC_UINT64) {
auto inBufferCast = reinterpret_cast<uint64_t*>(inBuffer);
auto outBufferCast = reinterpret_cast<uint64_t*>(outBuffer);

for (int slot = 0; slot < count; slot++) {
outBufferCast[slot] += inBufferCast[slot];
}
Expand Down
62 changes: 62 additions & 0 deletions tests/test/mpi/test_mpi_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,14 @@ template void doReduceTest<int>(MpiWorld& world,
std::vector<std::vector<int>> rankData,
std::vector<int>& expected);

template void doReduceTest<uint64_t>(
MpiWorld& world,
int root,
MPI_Op op,
MPI_Datatype datatype,
std::vector<std::vector<uint64_t>> rankData,
std::vector<uint64_t>& expected);

template void doReduceTest<double>(MpiWorld& world,
int root,
MPI_Op op,
Expand Down Expand Up @@ -790,6 +798,60 @@ TEST_CASE_METHOD(MpiTestFixture, "Test reduce", "[mpi]")
}
}

SECTION("UINT 64")
{
std::vector<std::vector<uint64_t>> rankData(worldSize,
std::vector<uint64_t>(3));
std::vector<uint64_t> expected(3, 0);

// Prepare rank data
for (int r = 0; r < worldSize; r++) {
rankData[r][0] = r;
rankData[r][1] = r * 10;
rankData[r][2] = r * 100;
}

SECTION("Sum operator")
{
for (int r = 0; r < worldSize; r++) {
expected[0] += rankData[r][0];
expected[1] += rankData[r][1];
expected[2] += rankData[r][2];
}

doReduceTest<uint64_t>(
world, root, MPI_SUM, MPI_UINT64_T, rankData, expected);
}

SECTION("Max operator")
{
expected[0] = (worldSize - 1);
expected[1] = (worldSize - 1) * 10;
expected[2] = (worldSize - 1) * 100;

doReduceTest<uint64_t>(
world, root, MPI_MAX, MPI_UINT64_T, rankData, expected);
}

SECTION("Min operator")
{
// Initialize rankData to non-zero values. This catches faulty
// reduce implementations that always return zero
for (int r = 0; r < worldSize; r++) {
rankData[r][0] = (r + 1);
rankData[r][1] = (r + 1) * 10;
rankData[r][2] = (r + 1) * 100;
}

expected[0] = 1;
expected[1] = 10;
expected[2] = 100;

doReduceTest<uint64_t>(
world, root, MPI_MIN, MPI_UINT64_T, rankData, expected);
}
}

SECTION("Doubles")
{
std::vector<std::vector<double>> rankData(worldSize,
Expand Down

0 comments on commit cb9cf15

Please sign in to comment.