From b5b8123aca752c9e83f38608bbf80501fc09f6c3 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Thu, 22 Feb 2024 17:51:27 +0000 Subject: [PATCH] mpi: #385 and #379 --- .github/workflows/tests.yml | 2 + include/faabric/mpi/MpiMessage.h | 49 +++++ include/faabric/mpi/MpiMessageBuffer.h | 12 +- include/faabric/mpi/MpiWorld.h | 42 ++-- include/faabric/util/queue.h | 24 ++- src/mpi/CMakeLists.txt | 22 +-- src/mpi/MpiMessage.cpp | 36 ++++ src/mpi/MpiWorld.cpp | 200 +++++++++----------- src/mpi/mpi.proto | 35 ---- tests/dist/mpi/mpi_native.cpp | 8 +- tests/test/mpi/test_mpi_message.cpp | 123 ++++++++++++ tests/test/mpi/test_mpi_message_buffer.cpp | 2 +- tests/test/mpi/test_mpi_world.cpp | 29 ++- tests/test/mpi/test_multiple_mpi_worlds.cpp | 23 --- tests/test/mpi/test_remote_mpi_worlds.cpp | 14 +- 15 files changed, 378 insertions(+), 243 deletions(-) create mode 100644 include/faabric/mpi/MpiMessage.h create mode 100644 src/mpi/MpiMessage.cpp delete mode 100644 src/mpi/mpi.proto create mode 100644 tests/test/mpi/test_mpi_message.cpp diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9eebd4ffb..ceb787294 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -105,6 +105,7 @@ jobs: if: github.event.pull_request.draft == false needs: [conan-cache] runs-on: ubuntu-latest + timeout-minutes: 20 strategy: fail-fast: false matrix: @@ -139,6 +140,7 @@ jobs: if: github.event.pull_request.draft == false needs: [conan-cache] runs-on: ubuntu-latest + timeout-minutes: 20 env: CONAN_CACHE_MOUNT_SOURCE: ~/.conan/ steps: diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h new file mode 100644 index 000000000..7c85fde48 --- /dev/null +++ b/include/faabric/mpi/MpiMessage.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include + +namespace faabric::mpi { + +enum MpiMessageType : int32_t +{ + NORMAL = 0, + BARRIER_JOIN = 1, + BARRIER_DONE = 2, + SCATTER = 3, + GATHER = 4, + ALLGATHER = 5, + REDUCE = 6, + SCAN = 7, + ALLREDUCE = 8, + ALLTOALL = 9, + SENDRECV = 10, + BROADCAST = 11, +}; + +struct MpiMessage +{ + int32_t id; + int32_t worldId; + int32_t sendRank; + int32_t recvRank; + int32_t typeSize; + int32_t count; + MpiMessageType messageType; + void* buffer; +}; + +inline size_t payloadSize(const MpiMessage& msg) +{ + return msg.typeSize * msg.count; +} + +inline size_t msgSize(const MpiMessage& msg) +{ + return sizeof(MpiMessage) + payloadSize(msg); +} + +void serializeMpiMsg(std::vector& buffer, const MpiMessage& msg); + +void parseMpiMsg(const std::vector& bytes, MpiMessage* msg); +} diff --git a/include/faabric/mpi/MpiMessageBuffer.h b/include/faabric/mpi/MpiMessageBuffer.h index 9fc67b644..c36f89887 100644 --- a/include/faabric/mpi/MpiMessageBuffer.h +++ b/include/faabric/mpi/MpiMessageBuffer.h @@ -1,8 +1,9 @@ +#include #include -#include #include #include +#include namespace faabric::mpi { /* The MPI message buffer (MMB) keeps track of the asyncrhonous @@ -25,17 +26,20 @@ class MpiMessageBuffer { public: int requestId = -1; - std::shared_ptr msg = nullptr; + std::shared_ptr msg = nullptr; int sendRank = -1; int recvRank = -1; uint8_t* buffer = nullptr; faabric_datatype_t* dataType = nullptr; int count = -1; - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL; + MpiMessageType messageType = MpiMessageType::NORMAL; bool isAcknowledged() { return msg != nullptr; } - void acknowledge(std::shared_ptr msgIn) { msg = msgIn; } + void acknowledge(const MpiMessage& msgIn) + { + msg = std::make_shared(msgIn); + } }; /* Interface to query the buffer size */ diff --git a/include/faabric/mpi/MpiWorld.h b/include/faabric/mpi/MpiWorld.h index 2402d2e36..97fb24f18 100644 --- a/include/faabric/mpi/MpiWorld.h +++ b/include/faabric/mpi/MpiWorld.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include #include @@ -26,10 +26,9 @@ namespace faabric::mpi { // ----------------------------------- // MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker // as the broker already has mocking capabilities -std::vector> getMpiMockedMessages(int sendRank); +std::vector getMpiMockedMessages(int sendRank); -typedef faabric::util::SpinLockQueue> - InMemoryMpiQueue; +typedef faabric::util::SpinLockQueue InMemoryMpiQueue; class MpiWorld { @@ -73,21 +72,21 @@ class MpiWorld const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); int isend(int sendRank, int recvRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void broadcast(int rootRank, int thisRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void recv(int sendRank, int recvRank, @@ -95,14 +94,14 @@ class MpiWorld faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); int irecv(int sendRank, int recvRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void awaitAsyncRequest(int requestId); @@ -185,8 +184,6 @@ class MpiWorld std::shared_ptr getLocalQueue(int sendRank, int recvRank); - long getLocalQueueSize(int sendRank, int recvRank); - void overrideHost(const std::string& newHost); double getWTime(); @@ -240,29 +237,36 @@ class MpiWorld void sendRemoteMpiMessage(std::string dstHost, int sendRank, int recvRank, - const std::shared_ptr& msg); + const MpiMessage& msg); - std::shared_ptr recvRemoteMpiMessage(int sendRank, - int recvRank); + MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank); // Support for asyncrhonous communications std::shared_ptr getUnackedMessageBuffer(int sendRank, int recvRank); - std::shared_ptr recvBatchReturnLast(int sendRank, - int recvRank, - int batchSize = 0); + MpiMessage recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize = 0); /* Helper methods */ void checkRanksRange(int sendRank, int recvRank); // Abstraction of the bulk of the recv work, shared among various functions - void doRecv(std::shared_ptr& m, + void doRecv(const MpiMessage& m, + uint8_t* buffer, + faabric_datatype_t* dataType, + int count, + MPI_Status* status, + MpiMessageType messageType = MpiMessageType::NORMAL); + + // Abstraction of the bulk of the recv work, shared among various functions + void doRecv(std::unique_ptr m, uint8_t* buffer, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); }; } diff --git a/include/faabric/util/queue.h b/include/faabric/util/queue.h index 28462a32b..9f9e2f164 100644 --- a/include/faabric/util/queue.h +++ b/include/faabric/util/queue.h @@ -222,24 +222,34 @@ template class SpinLockQueue { public: - void enqueue(T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) { - while (!mq.push(value)) { ; }; + void enqueue(T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) + { + while (!mq.push(value)) { + ; + }; } - T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) { + T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) + { T value; - while (!mq.pop(value)) { ; } + while (!mq.pop(value)) { + ; + } return value; } - long size() { + long size() + { throw std::runtime_error("Size for fast queue unimplemented!"); } - void drain() { - while (mq.pop()) { ; } + void drain() + { + while (mq.pop()) { + ; + } } void reset() { ; } diff --git a/src/mpi/CMakeLists.txt b/src/mpi/CMakeLists.txt index dfd434c87..3ac5b98c9 100644 --- a/src/mpi/CMakeLists.txt +++ b/src/mpi/CMakeLists.txt @@ -38,32 +38,12 @@ endif() # ----------------------------------------------- if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi")) - # Generate protobuf headers - set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h") - - protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto) - - # Copy the generated headers into place - add_custom_command( - OUTPUT "${MPI_PB_HEADER_COPIED}" - DEPENDS "${MPI_PB_HEADER}" - COMMAND ${CMAKE_COMMAND} - ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/ - ) - - add_custom_target( - mpi_pbh_copied - DEPENDS ${MPI_PB_HEADER_COPIED} - ) - - add_dependencies(faabric_common_dependencies mpi_pbh_copied) - faabric_lib(mpi MpiContext.cpp + MpiMessage.cpp MpiMessageBuffer.cpp MpiWorld.cpp MpiWorldRegistry.cpp - ${MPI_PB_SRC} ) target_link_libraries(mpi PRIVATE diff --git a/src/mpi/MpiMessage.cpp b/src/mpi/MpiMessage.cpp new file mode 100644 index 000000000..57ee8c85e --- /dev/null +++ b/src/mpi/MpiMessage.cpp @@ -0,0 +1,36 @@ +#include +#include + +#include +#include +#include + +namespace faabric::mpi { + +void parseMpiMsg(const std::vector& bytes, MpiMessage* msg) +{ + assert(msg != nullptr); + assert(bytes.size() >= sizeof(MpiMessage)); + std::memcpy(msg, bytes.data(), sizeof(MpiMessage)); + size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage); + assert(thisPayloadSize == payloadSize(*msg)); + + if (thisPayloadSize == 0) { + msg->buffer = nullptr; + return; + } + + msg->buffer = faabric::util::malloc(thisPayloadSize); + std::memcpy( + msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize); +} + +void serializeMpiMsg(std::vector& buffer, const MpiMessage& msg) +{ + std::memcpy(buffer.data(), &msg, sizeof(MpiMessage)); + size_t payloadSz = payloadSize(msg); + if (payloadSz > 0 && msg.buffer != nullptr) { + std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz); + } +} +} diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp index 9bca5dcb1..e1ec99e67 100644 --- a/src/mpi/MpiWorld.cpp +++ b/src/mpi/MpiWorld.cpp @@ -1,6 +1,6 @@ #include +#include #include -#include #include #include #include @@ -34,10 +34,9 @@ static std::mutex mockMutex; // The identifier in this map is the sending rank. For the receiver's rank // we can inspect the MPIMessage object -static std::map>> - mpiMockedMessages; +static std::map> mpiMockedMessages; -std::vector> getMpiMockedMessages(int sendRank) +std::vector getMpiMockedMessages(int sendRank) { faabric::util::UniqueLock lock(mockMutex); return mpiMockedMessages[sendRank]; @@ -53,12 +52,12 @@ MpiWorld::MpiWorld() void MpiWorld::sendRemoteMpiMessage(std::string dstHost, int sendRank, int recvRank, - const std::shared_ptr& msg) + const MpiMessage& msg) { - std::string serialisedBuffer; - if (!msg->SerializeToString(&serialisedBuffer)) { - throw std::runtime_error("Error serialising message"); - } + // Serialise + std::vector serialisedBuffer(msgSize(msg)); + serializeMpiMsg(serialisedBuffer, msg); + try { broker.sendMessage( thisRankMsg->groupid(), @@ -79,8 +78,7 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, } } -std::shared_ptr MpiWorld::recvRemoteMpiMessage(int sendRank, - int recvRank) +MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank) { std::vector msg; try { @@ -95,8 +93,12 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage(int sendRank, recvRank); throw e; } - PARSE_MSG(MPIMessage, msg.data(), msg.size()); - return std::make_shared(parsedMsg); + + // TODO(mpi-opt): make sure we minimze copies here + MpiMessage parsedMsg; + parseMpiMsg(msg, &parsedMsg); + + return parsedMsg; } std::shared_ptr MpiWorld::getUnackedMessageBuffer( @@ -447,7 +449,7 @@ int MpiWorld::isend(int sendRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { int requestId = (int)faabric::util::generateGid(); iSendRequests.insert(requestId); @@ -462,7 +464,7 @@ int MpiWorld::irecv(int sendRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { int requestId = (int)faabric::util::generateGid(); reqIdToRanks.try_emplace(requestId, sendRank, recvRank); @@ -489,7 +491,7 @@ void MpiWorld::send(int sendRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); @@ -506,45 +508,39 @@ void MpiWorld::send(int sendRank, // Generate a message ID int msgId = (localMsgCount + 1) % INT32_MAX; - // Create the message - auto m = std::make_shared(); - m->set_id(msgId); - m->set_worldid(id); - m->set_sender(sendRank); - m->set_destination(recvRank); - m->set_type(dataType->id); - m->set_count(count); - m->set_messagetype(messageType); - - // Set up message data - bool mustSendData = count > 0 && buffer != nullptr; + MpiMessage msg = { .id = msgId, + .worldId = id, + .sendRank = sendRank, + .recvRank = recvRank, + .typeSize = dataType->size, + .count = count, + .messageType = messageType, + .buffer = (void*)buffer }; // Mock the message sending in tests if (faabric::util::isMockMode()) { - mpiMockedMessages[sendRank].push_back(m); + mpiMockedMessages[sendRank].push_back(msg); return; } // Dispatch the message locally or globally if (isLocal) { - if (mustSendData) { + // Take control over the buffer data if we are gonna move it to + // the in-memory queues for local messaging + if (count > 0 && buffer != nullptr) { void* bufferPtr = faabric::util::malloc(count * dataType->size); std::memcpy(bufferPtr, buffer, count * dataType->size); - m->set_bufferptr((uint64_t)bufferPtr); + msg.buffer = bufferPtr; } SPDLOG_TRACE( "MPI - send {} -> {} ({})", sendRank, recvRank, messageType); - getLocalQueue(sendRank, recvRank)->enqueue(m); + getLocalQueue(sendRank, recvRank)->enqueue(msg); } else { - if (mustSendData) { - m->set_buffer(buffer, dataType->size * count); - } - SPDLOG_TRACE( "MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType); - sendRemoteMpiMessage(otherHost, sendRank, recvRank, m); + sendRemoteMpiMessage(otherHost, sendRank, recvRank, msg); } /* 02/05/2022 - The following bit of code fails randomly with a protobuf @@ -572,7 +568,7 @@ void MpiWorld::recv(int sendRank, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); @@ -582,54 +578,47 @@ void MpiWorld::recv(int sendRank, return; } - // Recv message from underlying transport - std::shared_ptr m = recvBatchReturnLast(sendRank, recvRank); + auto msg = recvBatchReturnLast(sendRank, recvRank); - // Do the processing - doRecv(m, buffer, dataType, count, status, messageType); + doRecv(std::move(msg), buffer, dataType, count, status, messageType); } -void MpiWorld::doRecv(std::shared_ptr& m, +void MpiWorld::doRecv(const MpiMessage& m, uint8_t* buffer, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Assert message integrity // Note - this checks won't happen in Release builds - if (m->messagetype() != messageType) { + if (m.messageType != messageType) { SPDLOG_ERROR("Different message types (got: {}, expected: {})", - m->messagetype(), + m.messageType, messageType); } - assert(m->messagetype() == messageType); - assert(m->count() <= count); - - const std::string otherHost = getHostForRank(m->destination()); - bool isLocal = - getHostForRank(m->destination()) == getHostForRank(m->sender()); - - if (m->count() > 0) { - if (isLocal) { - // Make sure we do not overflow the recepient buffer - auto bytesToCopy = std::min(m->count() * dataType->size, - count * dataType->size); - std::memcpy(buffer, (void*)m->bufferptr(), bytesToCopy); - faabric::util::free((void*)m->bufferptr()); - } else { - // TODO - avoid copy here - std::move(m->buffer().begin(), m->buffer().end(), buffer); - } + assert(m.messageType == messageType); + assert(m.count <= count); + + // We must copy the data into the application-provided buffer + if (m.count > 0 && m.buffer != nullptr) { + // Make sure we do not overflow the recepient buffer + auto bytesToCopy = + std::min(m.count * dataType->size, count * dataType->size); + std::memcpy(buffer, m.buffer, bytesToCopy); + + // This buffer has been malloc-ed either as part of a local `send` + // or as part of a remote `parseMpiMsg` + faabric::util::free((void*)m.buffer); } // Set status values if required if (status != nullptr) { - status->MPI_SOURCE = m->sender(); + status->MPI_SOURCE = m.sendRank; status->MPI_ERROR = MPI_SUCCESS; // Take the message size here as the receive count may be larger - status->bytesSize = m->count() * dataType->size; + status->bytesSize = m.count * dataType->size; // TODO - thread through tag status->MPI_TAG = -1; @@ -667,14 +656,14 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer, recvBuffer, recvDataType, recvCount, - MPIMessage::SENDRECV); + MpiMessageType::SENDRECV); // Then send the message send(myRank, sendRank, sendBuffer, sendDataType, sendCount, - MPIMessage::SENDRECV); + MpiMessageType::SENDRECV); // And wait awaitAsyncRequest(recvId); } @@ -684,7 +673,7 @@ void MpiWorld::broadcast(int sendRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { SPDLOG_TRACE("MPI - bcast {} -> {}", sendRank, recvRank); @@ -795,7 +784,7 @@ void MpiWorld::scatter(int sendRank, startPtr, sendType, sendCount, - MPIMessage::SCATTER); + MpiMessageType::SCATTER); } } } else { @@ -806,7 +795,7 @@ void MpiWorld::scatter(int sendRank, recvType, recvCount, nullptr, - MPIMessage::SCATTER); + MpiMessageType::SCATTER); } } @@ -880,7 +869,7 @@ void MpiWorld::gather(int sendRank, recvType, recvCount, nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); } } } else { @@ -894,7 +883,7 @@ void MpiWorld::gather(int sendRank, recvType, recvCount * it.second.size(), nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); // Copy each received chunk to its offset for (int r = 0; r < it.second.size(); r++) { @@ -924,7 +913,7 @@ void MpiWorld::gather(int sendRank, sendType, sendCount, nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); } } @@ -934,7 +923,7 @@ void MpiWorld::gather(int sendRank, rankData.get(), sendType, sendCount * ranksForHost[thisHost].size(), - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (isLocalLeader && isLocalGather) { // Scenario 3 @@ -943,7 +932,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (!isLocalLeader && !isLocalGather) { // Scenario 4 send(sendRank, @@ -951,7 +940,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (!isLocalLeader && isLocalGather) { // Scenario 5 send(sendRank, @@ -959,7 +948,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else { SPDLOG_ERROR("Don't know how to gather rank's data."); SPDLOG_ERROR("- sendRank: {}\n- recvRank: {}\n- isGatherReceiver: " @@ -1001,7 +990,7 @@ void MpiWorld::allGather(int rank, // Do a broadcast with a hard-coded root broadcast( - root, rank, recvBuffer, recvType, fullCount, MPIMessage::ALLGATHER); + root, rank, recvBuffer, recvType, fullCount, MpiMessageType::ALLGATHER); } void MpiWorld::awaitAsyncRequest(int requestId) @@ -1033,10 +1022,10 @@ void MpiWorld::awaitAsyncRequest(int requestId) std::list::iterator msgIt = umb->getRequestPendingMsg(requestId); - std::shared_ptr m; + MpiMessage m; if (msgIt->msg != nullptr) { // This id has already been acknowledged by a recv call, so do the recv - m = msgIt->msg; + m = *(msgIt->msg); } else { // We need to acknowledge all messages not acknowledged from the // begining until us @@ -1094,7 +1083,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce( operation, datatype, count, rankData.get(), recvBuffer); @@ -1108,7 +1097,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce( operation, datatype, count, rankData.get(), recvBuffer); @@ -1138,7 +1127,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce(operation, datatype, @@ -1152,7 +1141,7 @@ void MpiWorld::reduce(int sendRank, sendBufferCopy.get(), datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } else { // Send to the receiver rank send(sendRank, @@ -1160,7 +1149,7 @@ void MpiWorld::reduce(int sendRank, sendBuffer, datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } } else { // If we are neither the receiver of the reduce nor a local leader, we @@ -1175,7 +1164,7 @@ void MpiWorld::reduce(int sendRank, sendBuffer, datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } } @@ -1191,7 +1180,7 @@ void MpiWorld::allReduce(int rank, reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation); // Second, 0 broadcasts the result to all ranks - broadcast(0, rank, recvBuffer, datatype, count, MPIMessage::ALLREDUCE); + broadcast(0, rank, recvBuffer, datatype, count, MpiMessageType::ALLREDUCE); } void MpiWorld::op_reduce(faabric_op_t* operation, @@ -1350,14 +1339,14 @@ void MpiWorld::scan(int rank, datatype, count, nullptr, - MPIMessage::SCAN); + MpiMessageType::SCAN); // Reduce with our own value op_reduce(operation, datatype, count, currentAcc.get(), recvBuffer); } // If not the last process, send to the next one if (rank < this->size - 1) { - send(rank, rank + 1, recvBuffer, MPI_INT, count, MPIMessage::SCAN); + send(rank, rank + 1, recvBuffer, MPI_INT, count, MpiMessageType::SCAN); } } @@ -1385,7 +1374,12 @@ void MpiWorld::allToAll(int rank, sendChunk, sendChunk + sendOffset, recvBuffer + rankOffset); } else { // Send message to other rank - send(rank, r, sendChunk, sendType, sendCount, MPIMessage::ALLTOALL); + send(rank, + r, + sendChunk, + sendType, + sendCount, + MpiMessageType::ALLTOALL); } } @@ -1405,7 +1399,7 @@ void MpiWorld::allToAll(int rank, recvType, recvCount, nullptr, - MPIMessage::ALLTOALL); + MpiMessageType::ALLTOALL); } } @@ -1439,17 +1433,17 @@ void MpiWorld::barrier(int thisRank) // Await messages from all others for (int r = 1; r < size; r++) { MPI_Status s{}; - recv(r, 0, nullptr, MPI_INT, 0, &s, MPIMessage::BARRIER_JOIN); + recv(r, 0, nullptr, MPI_INT, 0, &s, MpiMessageType::BARRIER_JOIN); SPDLOG_TRACE("MPI - recv barrier join {}", s.MPI_SOURCE); } } else { // Tell the root that we're waiting SPDLOG_TRACE("MPI - barrier join {}", thisRank); - send(thisRank, 0, nullptr, MPI_INT, 0, MPIMessage::BARRIER_JOIN); + send(thisRank, 0, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_JOIN); } // Rank 0 broadcasts that the barrier is done (the others block here) - broadcast(0, thisRank, nullptr, MPI_INT, 0, MPIMessage::BARRIER_DONE); + broadcast(0, thisRank, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_DONE); SPDLOG_TRACE("MPI - barrier done {}", thisRank); } @@ -1479,9 +1473,10 @@ void MpiWorld::initLocalQueues() } } -std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, - int recvRank, - int batchSize) +// TODO(mpi-opt): double-check that the fast (no-async) path is fast +MpiMessage MpiWorld::recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize) { std::shared_ptr umb = getUnackedMessageBuffer(sendRank, recvRank); @@ -1501,7 +1496,7 @@ std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, // Recv message: first we receive all messages for which there is an id // in the unacknowleged buffer but no msg. Note that these messages // (batchSize - 1) were `irecv`-ed before ours. - std::shared_ptr ourMsg; + MpiMessage ourMsg; auto msgIt = umb->getFirstNullMsg(); if (isLocal) { // First receive messages that happened before us @@ -1567,13 +1562,6 @@ int MpiWorld::getIndexForRanks(int sendRank, int recvRank) const return index; } -long MpiWorld::getLocalQueueSize(int sendRank, int recvRank) -{ - const std::shared_ptr& queue = - getLocalQueue(sendRank, recvRank); - return queue->size(); -} - double MpiWorld::getWTime() { double t = faabric::util::getTimeDiffMillis(creationTime); diff --git a/src/mpi/mpi.proto b/src/mpi/mpi.proto deleted file mode 100644 index 80a690820..000000000 --- a/src/mpi/mpi.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; - -package faabric.mpi; - -message MPIMessage { - enum MPIMessageType { - NORMAL = 0; - BARRIER_JOIN = 1; - BARRIER_DONE = 2; - SCATTER = 3; - GATHER = 4; - ALLGATHER = 5; - REDUCE = 6; - SCAN = 7; - ALLREDUCE = 8; - ALLTOALL = 9; - SENDRECV = 10; - BROADCAST = 11; - }; - - MPIMessageType messageType = 1; - - int32 id = 2; - int32 worldId = 3; - int32 sender = 4; - int32 destination = 5; - int32 type = 6; - int32 count = 7; - - // For remote messaging - optional bytes buffer = 8; - - // For local messaging - optional int64 bufferPtr = 9; -} diff --git a/tests/dist/mpi/mpi_native.cpp b/tests/dist/mpi/mpi_native.cpp index d41235940..a499fb357 100644 --- a/tests/dist/mpi/mpi_native.cpp +++ b/tests/dist/mpi/mpi_native.cpp @@ -2,9 +2,9 @@ #include #include +#include #include #include -#include #include #include #include @@ -126,7 +126,7 @@ int MPI_Send(const void* buf, (uint8_t*)buf, datatype, count, - MPIMessage::NORMAL); + MpiMessageType::NORMAL); return MPI_SUCCESS; } @@ -159,7 +159,7 @@ int MPI_Recv(void* buf, datatype, count, status, - MPIMessage::NORMAL); + MpiMessageType::NORMAL); return MPI_SUCCESS; } @@ -245,7 +245,7 @@ int MPI_Bcast(void* buffer, int rank = executingContext.getRank(); world.broadcast( - root, rank, (uint8_t*)buffer, datatype, count, MPIMessage::BROADCAST); + root, rank, (uint8_t*)buffer, datatype, count, MpiMessageType::BROADCAST); return MPI_SUCCESS; } diff --git a/tests/test/mpi/test_mpi_message.cpp b/tests/test/mpi/test_mpi_message.cpp new file mode 100644 index 000000000..9c79f8d3b --- /dev/null +++ b/tests/test/mpi/test_mpi_message.cpp @@ -0,0 +1,123 @@ +#include + +#include +#include + +#include + +using namespace faabric::mpi; + +namespace tests { + +bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB) +{ + auto sizeA = msgSize(msgA); + auto sizeB = msgSize(msgB); + + if (sizeA != sizeB) { + return false; + } + + // First, compare the message body (excluding the pointer, which we + // know is at the end) + if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) { + return false; + } + + // Check that if one buffer points to null, so must do the other + if (msgA.buffer == nullptr || msgB.buffer == nullptr) { + return msgA.buffer == msgB.buffer; + } + + // If none points to null, they must point to the same data + auto payloadSizeA = payloadSize(msgA); + auto payloadSizeB = payloadSize(msgB); + // Assert, as this should pass given the previous comparisons + assert(payloadSizeA == payloadSizeB); + + return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0; +} + +TEST_CASE("Test getting a message size", "[mpi]") +{ + MpiMessage msg = { .id = 1, + .worldId = 3, + .sendRank = 3, + .recvRank = 7, + .typeSize = 1, + .count = 3, + .messageType = MpiMessageType::NORMAL }; + + size_t expectedMsgSize = 0; + size_t expectedPayloadSize = 0; + + SECTION("Empty message") + { + msg.buffer = nullptr; + msg.count = 0; + expectedMsgSize = sizeof(MpiMessage); + expectedPayloadSize = 0; + } + + SECTION("Non-empty message") + { + std::vector nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.count = nums.size(); + msg.typeSize = sizeof(int); + msg.buffer = faabric::util::malloc(msg.count * msg.typeSize); + std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int)); + + expectedPayloadSize = sizeof(int) * nums.size(); + expectedMsgSize = sizeof(MpiMessage) + expectedPayloadSize; + } + + REQUIRE(expectedMsgSize == msgSize(msg)); + REQUIRE(expectedPayloadSize == payloadSize(msg)); + + if (msg.buffer != nullptr) { + faabric::util::free(msg.buffer); + } +} + +TEST_CASE("Test (de)serialising an MPI message", "[mpi]") +{ + MpiMessage msg = { .id = 1, + .worldId = 3, + .sendRank = 3, + .recvRank = 7, + .typeSize = 1, + .count = 3, + .messageType = MpiMessageType::NORMAL }; + + SECTION("Empty message") + { + msg.count = 0; + msg.buffer = nullptr; + } + + SECTION("Non-empty message") + { + std::vector nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.count = nums.size(); + msg.typeSize = sizeof(int); + msg.buffer = faabric::util::malloc(msg.count * msg.typeSize); + std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int)); + } + + // Serialise and de-serialise + std::vector buffer(msgSize(msg)); + serializeMpiMsg(buffer, msg); + + MpiMessage parsedMsg; + parseMpiMsg(buffer, &parsedMsg); + + REQUIRE(areMpiMsgEqual(msg, parsedMsg)); + + if (msg.buffer != nullptr) { + faabric::util::free(msg.buffer); + } + if (parsedMsg.buffer != nullptr) { + faabric::util::free(parsedMsg.buffer); + } +} +} diff --git a/tests/test/mpi/test_mpi_message_buffer.cpp b/tests/test/mpi/test_mpi_message_buffer.cpp index 1674172fd..710a3c259 100644 --- a/tests/test/mpi/test_mpi_message_buffer.cpp +++ b/tests/test/mpi/test_mpi_message_buffer.cpp @@ -21,7 +21,7 @@ MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments( pendingMsg.requestId = requestId; if (!nullMsg) { - pendingMsg.msg = std::make_shared(); + pendingMsg.msg = std::make_shared(); } return pendingMsg; diff --git a/tests/test/mpi/test_mpi_world.cpp b/tests/test/mpi/test_mpi_world.cpp index 8c1aca149..05387cfa1 100644 --- a/tests/test/mpi/test_mpi_world.cpp +++ b/tests/test/mpi/test_mpi_world.cpp @@ -212,23 +212,22 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test local barrier", "[mpi]") world.destroy(); } -void checkMessage(MPIMessage& actualMessage, +void checkMessage(MpiMessage& actualMessage, int worldId, int senderRank, int destRank, const std::vector& data) { // Check the message contents - REQUIRE(actualMessage.worldid() == worldId); - REQUIRE(actualMessage.count() == data.size()); - REQUIRE(actualMessage.destination() == destRank); - REQUIRE(actualMessage.sender() == senderRank); - REQUIRE(actualMessage.type() == FAABRIC_INT); + REQUIRE(actualMessage.worldId == worldId); + REQUIRE(actualMessage.count == data.size()); + REQUIRE(actualMessage.recvRank == destRank); + REQUIRE(actualMessage.sendRank == senderRank); + REQUIRE(actualMessage.typeSize == FAABRIC_INT); // Check data - const auto* rawInts = - reinterpret_cast(actualMessage.buffer().c_str()); - size_t nInts = actualMessage.buffer().size() / sizeof(int); + const auto* rawInts = reinterpret_cast(actualMessage.buffer); + size_t nInts = payloadSize(actualMessage) / sizeof(int); std::vector actualData(rawInts, rawInts + nInts); REQUIRE(actualData == data); } @@ -396,10 +395,10 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send/recv message with no data", "[mpi]") SECTION("Check on queue") { // Check message content - MPIMessage actualMessage = - *(world.getLocalQueue(rankA1, rankA2)->dequeue()); - REQUIRE(actualMessage.count() == 0); - REQUIRE(actualMessage.type() == FAABRIC_INT); + MpiMessage actualMessage = + world.getLocalQueue(rankA1, rankA2)->dequeue(); + REQUIRE(actualMessage.count == 0); + REQUIRE(actualMessage.typeSize == FAABRIC_INT); } SECTION("Check receiving with null ptr") @@ -502,7 +501,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]") BYTES(messageData.data()), MPI_INT, messageData.size(), - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); // Recv on all non-root ranks for (int rank = 0; rank < worldSize; rank++) { @@ -515,7 +514,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]") BYTES(actual.data()), MPI_INT, 3, - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); REQUIRE(actual == messageData); } } diff --git a/tests/test/mpi/test_multiple_mpi_worlds.cpp b/tests/test/mpi/test_multiple_mpi_worlds.cpp index 735556f6e..a1062e81b 100644 --- a/tests/test/mpi/test_multiple_mpi_worlds.cpp +++ b/tests/test/mpi/test_multiple_mpi_worlds.cpp @@ -155,29 +155,6 @@ TEST_CASE_METHOD(MultiWorldMpiTestFixture, worldB.send( rankA1, rankA2, BYTES(messageData.data()), MPI_INT, messageData.size()); - SECTION("Test queueing") - { - // Check for world A - REQUIRE(worldA.getLocalQueueSize(rankA1, rankA2) == 1); - REQUIRE(worldA.getLocalQueueSize(rankA2, rankA1) == 0); - REQUIRE(worldA.getLocalQueueSize(rankA1, 0) == 0); - REQUIRE(worldA.getLocalQueueSize(rankA2, 0) == 0); - const std::shared_ptr& queueA2 = - worldA.getLocalQueue(rankA1, rankA2); - MPIMessage actualMessage = *(queueA2->dequeue()); - // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData); - - // Check for world B - REQUIRE(worldB.getLocalQueueSize(rankA1, rankA2) == 1); - REQUIRE(worldB.getLocalQueueSize(rankA2, rankA1) == 0); - REQUIRE(worldB.getLocalQueueSize(rankA1, 0) == 0); - REQUIRE(worldB.getLocalQueueSize(rankA2, 0) == 0); - const std::shared_ptr& queueA2B = - worldB.getLocalQueue(rankA1, rankA2); - actualMessage = *(queueA2B->dequeue()); - // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData); - } - SECTION("Test recv") { MPI_Status status{}; diff --git a/tests/test/mpi/test_remote_mpi_worlds.cpp b/tests/test/mpi/test_remote_mpi_worlds.cpp index 1e56b48b1..54662929f 100644 --- a/tests/test/mpi/test_remote_mpi_worlds.cpp +++ b/tests/test/mpi/test_remote_mpi_worlds.cpp @@ -21,12 +21,11 @@ using namespace faabric::mpi; using namespace faabric::scheduler; namespace tests { -std::set getReceiversFromMessages( - std::vector> msgs) +std::set getReceiversFromMessages(std::vector msgs) { std::set receivers; for (const auto& msg : msgs) { - receivers.insert(msg->destination()); + receivers.insert(msg.recvRank); } return receivers; @@ -108,14 +107,14 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, BYTES(messageData.data()), MPI_INT, messageData.size(), - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); } else { otherWorld.broadcast(sendRank, recvRank, BYTES(messageData.data()), MPI_INT, messageData.size(), - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); } auto msgs = getMpiMockedMessages(recvRank); REQUIRE(msgs.size() == expectedNumMsg); @@ -219,12 +218,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, thisWorld.destroy(); } -std::set getMsgCountsFromMessages( - std::vector> msgs) +std::set getMsgCountsFromMessages(std::vector msgs) { std::set counts; for (const auto& msg : msgs) { - counts.insert(msg->count()); + counts.insert(msg.count); } return counts;