From 26568f78c368256eae5aa3cb74c94bc737bbd85a Mon Sep 17 00:00:00 2001 From: Carlos Segarra <carlos@carlossegarra.com> Date: Thu, 22 Feb 2024 17:51:27 +0000 Subject: [PATCH 1/2] mpi: move MpiMessage from protobuf to struct --- include/faabric/mpi/MpiMessage.h | 49 +++++ include/faabric/mpi/MpiMessageBuffer.h | 12 +- include/faabric/mpi/MpiWorld.h | 40 ++-- src/mpi/CMakeLists.txt | 22 +-- src/mpi/MpiMessage.cpp | 36 ++++ src/mpi/MpiWorld.cpp | 197 ++++++++++---------- src/mpi/mpi.proto | 35 ---- tests/dist/mpi/mpi_native.cpp | 8 +- tests/test/mpi/test_mpi_message.cpp | 123 ++++++++++++ tests/test/mpi/test_mpi_message_buffer.cpp | 2 +- tests/test/mpi/test_mpi_world.cpp | 29 ++- tests/test/mpi/test_multiple_mpi_worlds.cpp | 4 +- tests/test/mpi/test_remote_mpi_worlds.cpp | 14 +- 13 files changed, 364 insertions(+), 207 deletions(-) create mode 100644 include/faabric/mpi/MpiMessage.h create mode 100644 src/mpi/MpiMessage.cpp delete mode 100644 src/mpi/mpi.proto create mode 100644 tests/test/mpi/test_mpi_message.cpp diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h new file mode 100644 index 000000000..7c85fde48 --- /dev/null +++ b/include/faabric/mpi/MpiMessage.h @@ -0,0 +1,49 @@ +#pragma once + +#include <cstdint> +#include <vector> + +namespace faabric::mpi { + +enum MpiMessageType : int32_t +{ + NORMAL = 0, + BARRIER_JOIN = 1, + BARRIER_DONE = 2, + SCATTER = 3, + GATHER = 4, + ALLGATHER = 5, + REDUCE = 6, + SCAN = 7, + ALLREDUCE = 8, + ALLTOALL = 9, + SENDRECV = 10, + BROADCAST = 11, +}; + +struct MpiMessage +{ + int32_t id; + int32_t worldId; + int32_t sendRank; + int32_t recvRank; + int32_t typeSize; + int32_t count; + MpiMessageType messageType; + void* buffer; +}; + +inline size_t payloadSize(const MpiMessage& msg) +{ + return msg.typeSize * msg.count; +} + +inline size_t msgSize(const MpiMessage& msg) +{ + return sizeof(MpiMessage) + payloadSize(msg); +} + +void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg); + +void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg); +} diff --git a/include/faabric/mpi/MpiMessageBuffer.h b/include/faabric/mpi/MpiMessageBuffer.h index 9fc67b644..c36f89887 100644 --- a/include/faabric/mpi/MpiMessageBuffer.h +++ b/include/faabric/mpi/MpiMessageBuffer.h @@ -1,8 +1,9 @@ +#include <faabric/mpi/MpiMessage.h> #include <faabric/mpi/mpi.h> -#include <faabric/mpi/mpi.pb.h> #include <iterator> #include <list> +#include <memory> namespace faabric::mpi { /* The MPI message buffer (MMB) keeps track of the asyncrhonous @@ -25,17 +26,20 @@ class MpiMessageBuffer { public: int requestId = -1; - std::shared_ptr<MPIMessage> msg = nullptr; + std::shared_ptr<MpiMessage> msg = nullptr; int sendRank = -1; int recvRank = -1; uint8_t* buffer = nullptr; faabric_datatype_t* dataType = nullptr; int count = -1; - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL; + MpiMessageType messageType = MpiMessageType::NORMAL; bool isAcknowledged() { return msg != nullptr; } - void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; } + void acknowledge(const MpiMessage& msgIn) + { + msg = std::make_shared<MpiMessage>(msgIn); + } }; /* Interface to query the buffer size */ diff --git a/include/faabric/mpi/MpiWorld.h b/include/faabric/mpi/MpiWorld.h index adee54137..8f9cb918c 100644 --- a/include/faabric/mpi/MpiWorld.h +++ b/include/faabric/mpi/MpiWorld.h @@ -1,8 +1,8 @@ #pragma once +#include <faabric/mpi/MpiMessage.h> #include <faabric/mpi/MpiMessageBuffer.h> #include <faabric/mpi/mpi.h> -#include 
<faabric/mpi/mpi.pb.h> #include <faabric/proto/faabric.pb.h> #include <faabric/scheduler/InMemoryMessageQueue.h> #include <faabric/transport/PointToPointBroker.h> @@ -26,10 +26,9 @@ namespace faabric::mpi { // ----------------------------------- // MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker // as the broker already has mocking capabilities -std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank); +std::vector<MpiMessage> getMpiMockedMessages(int sendRank); -typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>> - InMemoryMpiQueue; +typedef faabric::util::FixedCapacityQueue<MpiMessage> InMemoryMpiQueue; class MpiWorld { @@ -73,21 +72,21 @@ class MpiWorld const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); int isend(int sendRank, int recvRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void broadcast(int rootRank, int thisRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void recv(int sendRank, int recvRank, @@ -95,14 +94,14 @@ class MpiWorld faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); int irecv(int sendRank, int recvRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void awaitAsyncRequest(int requestId); @@ -240,29 +239,36 @@ class MpiWorld void sendRemoteMpiMessage(std::string dstHost, int sendRank, int recvRank, - const std::shared_ptr<MPIMessage>& msg); + const MpiMessage& msg); - std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank, - int recvRank); + MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank); // Support for asyncrhonous communications std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank, int recvRank); - std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank, - int recvRank, - int batchSize = 0); + MpiMessage recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize = 0); /* Helper methods */ void checkRanksRange(int sendRank, int recvRank); // Abstraction of the bulk of the recv work, shared among various functions - void doRecv(std::shared_ptr<MPIMessage>& m, + void doRecv(const MpiMessage& m, uint8_t* buffer, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); + + // Abstraction of the bulk of the recv work, shared among various functions + void doRecv(std::unique_ptr<MpiMessage> m, + uint8_t* buffer, + faabric_datatype_t* dataType, + int count, + MPI_Status* status, + MpiMessageType messageType = MpiMessageType::NORMAL); }; } diff --git a/src/mpi/CMakeLists.txt b/src/mpi/CMakeLists.txt index dfd434c87..3ac5b98c9 100644 --- a/src/mpi/CMakeLists.txt +++ b/src/mpi/CMakeLists.txt @@ -38,32 +38,12 @@ endif() # ----------------------------------------------- if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi")) - # Generate protobuf headers - set(MPI_PB_HEADER_COPIED 
"${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h") - - protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto) - - # Copy the generated headers into place - add_custom_command( - OUTPUT "${MPI_PB_HEADER_COPIED}" - DEPENDS "${MPI_PB_HEADER}" - COMMAND ${CMAKE_COMMAND} - ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/ - ) - - add_custom_target( - mpi_pbh_copied - DEPENDS ${MPI_PB_HEADER_COPIED} - ) - - add_dependencies(faabric_common_dependencies mpi_pbh_copied) - faabric_lib(mpi MpiContext.cpp + MpiMessage.cpp MpiMessageBuffer.cpp MpiWorld.cpp MpiWorldRegistry.cpp - ${MPI_PB_SRC} ) target_link_libraries(mpi PRIVATE diff --git a/src/mpi/MpiMessage.cpp b/src/mpi/MpiMessage.cpp new file mode 100644 index 000000000..57ee8c85e --- /dev/null +++ b/src/mpi/MpiMessage.cpp @@ -0,0 +1,36 @@ +#include <faabric/mpi/MpiMessage.h> +#include <faabric/util/memory.h> + +#include <cassert> +#include <cstdint> +#include <cstring> + +namespace faabric::mpi { + +void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg) +{ + assert(msg != nullptr); + assert(bytes.size() >= sizeof(MpiMessage)); + std::memcpy(msg, bytes.data(), sizeof(MpiMessage)); + size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage); + assert(thisPayloadSize == payloadSize(*msg)); + + if (thisPayloadSize == 0) { + msg->buffer = nullptr; + return; + } + + msg->buffer = faabric::util::malloc(thisPayloadSize); + std::memcpy( + msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize); +} + +void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg) +{ + std::memcpy(buffer.data(), &msg, sizeof(MpiMessage)); + size_t payloadSz = payloadSize(msg); + if (payloadSz > 0 && msg.buffer != nullptr) { + std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz); + } +} +} diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp index d50344c40..5b0887c7a 100644 --- a/src/mpi/MpiWorld.cpp +++ b/src/mpi/MpiWorld.cpp @@ -1,6 +1,6 @@ #include <faabric/batch-scheduler/SchedulingDecision.h> +#include <faabric/mpi/MpiMessage.h> #include <faabric/mpi/MpiWorld.h> -#include <faabric/mpi/mpi.pb.h> #include <faabric/planner/PlannerClient.h> #include <faabric/transport/macros.h> #include <faabric/util/ExecGraph.h> @@ -34,10 +34,9 @@ static std::mutex mockMutex; // The identifier in this map is the sending rank. 
For the receiver's rank // we can inspect the MPIMessage object -static std::map<int, std::vector<std::shared_ptr<MPIMessage>>> - mpiMockedMessages; +static std::map<int, std::vector<MpiMessage>> mpiMockedMessages; -std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank) +std::vector<MpiMessage> getMpiMockedMessages(int sendRank) { faabric::util::UniqueLock lock(mockMutex); return mpiMockedMessages[sendRank]; @@ -53,12 +52,12 @@ MpiWorld::MpiWorld() void MpiWorld::sendRemoteMpiMessage(std::string dstHost, int sendRank, int recvRank, - const std::shared_ptr<MPIMessage>& msg) + const MpiMessage& msg) { - std::string serialisedBuffer; - if (!msg->SerializeToString(&serialisedBuffer)) { - throw std::runtime_error("Error serialising message"); - } + // Serialise + std::vector<uint8_t> serialisedBuffer(msgSize(msg)); + serializeMpiMsg(serialisedBuffer, msg); + try { broker.sendMessage( thisRankMsg->groupid(), @@ -79,8 +78,7 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, } } -std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank, - int recvRank) +MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank) { std::vector<uint8_t> msg; try { @@ -95,8 +93,12 @@ std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank, recvRank); throw e; } - PARSE_MSG(MPIMessage, msg.data(), msg.size()); - return std::make_shared<MPIMessage>(parsedMsg); + + // TODO(mpi-opt): make sure we minimze copies here + MpiMessage parsedMsg; + parseMpiMsg(msg, &parsedMsg); + + return parsedMsg; } std::shared_ptr<MpiMessageBuffer> MpiWorld::getUnackedMessageBuffer( @@ -447,7 +449,7 @@ int MpiWorld::isend(int sendRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { int requestId = (int)faabric::util::generateGid(); iSendRequests.insert(requestId); @@ -462,7 +464,7 @@ int MpiWorld::irecv(int sendRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { int requestId = (int)faabric::util::generateGid(); reqIdToRanks.try_emplace(requestId, sendRank, recvRank); @@ -489,7 +491,7 @@ void MpiWorld::send(int sendRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); @@ -506,45 +508,39 @@ void MpiWorld::send(int sendRank, // Generate a message ID int msgId = (localMsgCount + 1) % INT32_MAX; - // Create the message - auto m = std::make_shared<MPIMessage>(); - m->set_id(msgId); - m->set_worldid(id); - m->set_sender(sendRank); - m->set_destination(recvRank); - m->set_type(dataType->id); - m->set_count(count); - m->set_messagetype(messageType); - - // Set up message data - bool mustSendData = count > 0 && buffer != nullptr; + MpiMessage msg = { .id = msgId, + .worldId = id, + .sendRank = sendRank, + .recvRank = recvRank, + .typeSize = dataType->size, + .count = count, + .messageType = messageType, + .buffer = (void*)buffer }; // Mock the message sending in tests if (faabric::util::isMockMode()) { - mpiMockedMessages[sendRank].push_back(m); + mpiMockedMessages[sendRank].push_back(msg); return; } // Dispatch the message locally or globally if (isLocal) { - if (mustSendData) { + // Take control over the buffer data if we are gonna move it to + // the in-memory queues for local messaging + if (count > 0 && buffer != nullptr) { void* bufferPtr = 
faabric::util::malloc(count * dataType->size); std::memcpy(bufferPtr, buffer, count * dataType->size); - m->set_bufferptr((uint64_t)bufferPtr); + msg.buffer = bufferPtr; } SPDLOG_TRACE( "MPI - send {} -> {} ({})", sendRank, recvRank, messageType); - getLocalQueue(sendRank, recvRank)->enqueue(std::move(m)); + getLocalQueue(sendRank, recvRank)->enqueue(msg); } else { - if (mustSendData) { - m->set_buffer(buffer, dataType->size * count); - } - SPDLOG_TRACE( "MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType); - sendRemoteMpiMessage(otherHost, sendRank, recvRank, m); + sendRemoteMpiMessage(otherHost, sendRank, recvRank, msg); } /* 02/05/2022 - The following bit of code fails randomly with a protobuf @@ -572,7 +568,7 @@ void MpiWorld::recv(int sendRank, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); @@ -582,54 +578,47 @@ void MpiWorld::recv(int sendRank, return; } - // Recv message from underlying transport - std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank); + auto msg = recvBatchReturnLast(sendRank, recvRank); - // Do the processing - doRecv(m, buffer, dataType, count, status, messageType); + doRecv(std::move(msg), buffer, dataType, count, status, messageType); } -void MpiWorld::doRecv(std::shared_ptr<MPIMessage>& m, +void MpiWorld::doRecv(const MpiMessage& m, uint8_t* buffer, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Assert message integrity // Note - this checks won't happen in Release builds - if (m->messagetype() != messageType) { + if (m.messageType != messageType) { SPDLOG_ERROR("Different message types (got: {}, expected: {})", - m->messagetype(), + m.messageType, messageType); } - assert(m->messagetype() == messageType); - assert(m->count() <= count); - - const std::string otherHost = getHostForRank(m->destination()); - bool isLocal = - getHostForRank(m->destination()) == getHostForRank(m->sender()); - - if (m->count() > 0) { - if (isLocal) { - // Make sure we do not overflow the recepient buffer - auto bytesToCopy = std::min<size_t>(m->count() * dataType->size, - count * dataType->size); - std::memcpy(buffer, (void*)m->bufferptr(), bytesToCopy); - faabric::util::free((void*)m->bufferptr()); - } else { - // TODO - avoid copy here - std::move(m->buffer().begin(), m->buffer().end(), buffer); - } + assert(m.messageType == messageType); + assert(m.count <= count); + + // We must copy the data into the application-provided buffer + if (m.count > 0 && m.buffer != nullptr) { + // Make sure we do not overflow the recepient buffer + auto bytesToCopy = + std::min<size_t>(m.count * dataType->size, count * dataType->size); + std::memcpy(buffer, m.buffer, bytesToCopy); + + // This buffer has been malloc-ed either as part of a local `send` + // or as part of a remote `parseMpiMsg` + faabric::util::free((void*)m.buffer); } // Set status values if required if (status != nullptr) { - status->MPI_SOURCE = m->sender(); + status->MPI_SOURCE = m.sendRank; status->MPI_ERROR = MPI_SUCCESS; // Take the message size here as the receive count may be larger - status->bytesSize = m->count() * dataType->size; + status->bytesSize = m.count * dataType->size; // TODO - thread through tag status->MPI_TAG = -1; @@ -667,14 +656,14 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer, recvBuffer, recvDataType, recvCount, - 
MPIMessage::SENDRECV); + MpiMessageType::SENDRECV); // Then send the message send(myRank, sendRank, sendBuffer, sendDataType, sendCount, - MPIMessage::SENDRECV); + MpiMessageType::SENDRECV); // And wait awaitAsyncRequest(recvId); } @@ -684,7 +673,7 @@ void MpiWorld::broadcast(int sendRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { SPDLOG_TRACE("MPI - bcast {} -> {}", sendRank, recvRank); @@ -795,7 +784,7 @@ void MpiWorld::scatter(int sendRank, startPtr, sendType, sendCount, - MPIMessage::SCATTER); + MpiMessageType::SCATTER); } } } else { @@ -806,7 +795,7 @@ void MpiWorld::scatter(int sendRank, recvType, recvCount, nullptr, - MPIMessage::SCATTER); + MpiMessageType::SCATTER); } } @@ -880,7 +869,7 @@ void MpiWorld::gather(int sendRank, recvType, recvCount, nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); } } } else { @@ -894,7 +883,7 @@ void MpiWorld::gather(int sendRank, recvType, recvCount * it.second.size(), nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); // Copy each received chunk to its offset for (int r = 0; r < it.second.size(); r++) { @@ -924,7 +913,7 @@ void MpiWorld::gather(int sendRank, sendType, sendCount, nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); } } @@ -934,7 +923,7 @@ void MpiWorld::gather(int sendRank, rankData.get(), sendType, sendCount * ranksForHost[thisHost].size(), - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (isLocalLeader && isLocalGather) { // Scenario 3 @@ -943,7 +932,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (!isLocalLeader && !isLocalGather) { // Scenario 4 send(sendRank, @@ -951,7 +940,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (!isLocalLeader && isLocalGather) { // Scenario 5 send(sendRank, @@ -959,7 +948,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else { SPDLOG_ERROR("Don't know how to gather rank's data."); SPDLOG_ERROR("- sendRank: {}\n- recvRank: {}\n- isGatherReceiver: " @@ -1001,7 +990,7 @@ void MpiWorld::allGather(int rank, // Do a broadcast with a hard-coded root broadcast( - root, rank, recvBuffer, recvType, fullCount, MPIMessage::ALLGATHER); + root, rank, recvBuffer, recvType, fullCount, MpiMessageType::ALLGATHER); } void MpiWorld::awaitAsyncRequest(int requestId) @@ -1033,10 +1022,10 @@ void MpiWorld::awaitAsyncRequest(int requestId) std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt = umb->getRequestPendingMsg(requestId); - std::shared_ptr<MPIMessage> m; + MpiMessage m; if (msgIt->msg != nullptr) { // This id has already been acknowledged by a recv call, so do the recv - m = msgIt->msg; + m = *(msgIt->msg); } else { // We need to acknowledge all messages not acknowledged from the // begining until us @@ -1094,7 +1083,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce( operation, datatype, count, rankData.get(), recvBuffer); @@ -1108,7 +1097,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce( operation, datatype, count, rankData.get(), recvBuffer); @@ -1138,7 +1127,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, 
- MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce(operation, datatype, @@ -1152,7 +1141,7 @@ void MpiWorld::reduce(int sendRank, sendBufferCopy.get(), datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } else { // Send to the receiver rank send(sendRank, @@ -1160,7 +1149,7 @@ void MpiWorld::reduce(int sendRank, sendBuffer, datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } } else { // If we are neither the receiver of the reduce nor a local leader, we @@ -1175,7 +1164,7 @@ void MpiWorld::reduce(int sendRank, sendBuffer, datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } } @@ -1191,7 +1180,7 @@ void MpiWorld::allReduce(int rank, reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation); // Second, 0 broadcasts the result to all ranks - broadcast(0, rank, recvBuffer, datatype, count, MPIMessage::ALLREDUCE); + broadcast(0, rank, recvBuffer, datatype, count, MpiMessageType::ALLREDUCE); } void MpiWorld::op_reduce(faabric_op_t* operation, @@ -1350,14 +1339,14 @@ void MpiWorld::scan(int rank, datatype, count, nullptr, - MPIMessage::SCAN); + MpiMessageType::SCAN); // Reduce with our own value op_reduce(operation, datatype, count, currentAcc.get(), recvBuffer); } // If not the last process, send to the next one if (rank < this->size - 1) { - send(rank, rank + 1, recvBuffer, MPI_INT, count, MPIMessage::SCAN); + send(rank, rank + 1, recvBuffer, MPI_INT, count, MpiMessageType::SCAN); } } @@ -1385,7 +1374,12 @@ void MpiWorld::allToAll(int rank, sendChunk, sendChunk + sendOffset, recvBuffer + rankOffset); } else { // Send message to other rank - send(rank, r, sendChunk, sendType, sendCount, MPIMessage::ALLTOALL); + send(rank, + r, + sendChunk, + sendType, + sendCount, + MpiMessageType::ALLTOALL); } } @@ -1405,7 +1399,7 @@ void MpiWorld::allToAll(int rank, recvType, recvCount, nullptr, - MPIMessage::ALLTOALL); + MpiMessageType::ALLTOALL); } } @@ -1416,15 +1410,17 @@ void MpiWorld::allToAll(int rank, // queues. 
void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status) { + throw std::runtime_error("Probe not implemented!"); + /* const std::shared_ptr<InMemoryMpiQueue>& queue = getLocalQueue(sendRank, recvRank); - // 30/12/21 - Peek will throw a runtime error std::shared_ptr<MPIMessage> m = *(queue->peek()); faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type()); status->bytesSize = m->count() * datatype->size; status->MPI_ERROR = 0; status->MPI_SOURCE = m->sender(); + */ } void MpiWorld::barrier(int thisRank) @@ -1437,17 +1433,17 @@ void MpiWorld::barrier(int thisRank) // Await messages from all others for (int r = 1; r < size; r++) { MPI_Status s{}; - recv(r, 0, nullptr, MPI_INT, 0, &s, MPIMessage::BARRIER_JOIN); + recv(r, 0, nullptr, MPI_INT, 0, &s, MpiMessageType::BARRIER_JOIN); SPDLOG_TRACE("MPI - recv barrier join {}", s.MPI_SOURCE); } } else { // Tell the root that we're waiting SPDLOG_TRACE("MPI - barrier join {}", thisRank); - send(thisRank, 0, nullptr, MPI_INT, 0, MPIMessage::BARRIER_JOIN); + send(thisRank, 0, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_JOIN); } // Rank 0 broadcasts that the barrier is done (the others block here) - broadcast(0, thisRank, nullptr, MPI_INT, 0, MPIMessage::BARRIER_DONE); + broadcast(0, thisRank, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_DONE); SPDLOG_TRACE("MPI - barrier done {}", thisRank); } @@ -1477,9 +1473,10 @@ void MpiWorld::initLocalQueues() } } -std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank, - int recvRank, - int batchSize) +// TODO(mpi-opt): double-check that the fast (no-async) path is fast +MpiMessage MpiWorld::recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize) { std::shared_ptr<MpiMessageBuffer> umb = getUnackedMessageBuffer(sendRank, recvRank); @@ -1499,7 +1496,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank, // Recv message: first we receive all messages for which there is an id // in the unacknowleged buffer but no msg. Note that these messages // (batchSize - 1) were `irecv`-ed before ours. 
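For context on how the unacknowledged-message buffer described in the comment above is driven from the caller's side, here is a minimal sketch of the asynchronous path using the irecv/send/awaitAsyncRequest signatures from MpiWorld.h. It is an illustration of the flow, not code taken from the patch: the `world` object and the rank placement are assumed, and the helper function name is hypothetical.

    #include <faabric/mpi/MpiWorld.h>

    #include <cstdint>
    #include <vector>

    // Illustration only: `world` is assumed to be an initialised MpiWorld in
    // which this code runs as rank 1, with rank 0 as the sender.
    void exampleAsyncRecv(faabric::mpi::MpiWorld& world)
    {
        using faabric::mpi::MpiMessageType;

        // Post the receive first: irecv() only records a request id for the
        // (sendRank=0, recvRank=1) pair; no data is copied yet.
        std::vector<int> recvData(3, 0);
        int requestId = world.irecv(0,
                                    1,
                                    reinterpret_cast<uint8_t*>(recvData.data()),
                                    MPI_INT,
                                    static_cast<int>(recvData.size()),
                                    MpiMessageType::NORMAL);

        // ... rank 0 eventually calls world.send(0, 1, ...) with a matching
        // count and MpiMessageType::NORMAL ...

        // awaitAsyncRequest() then walks the unacknowledged buffer for (0, 1):
        // it first drains the earlier requests that have an id but no message
        // yet (the batch described in the comment above), then copies this
        // request's payload into recvData and frees the malloc-ed buffer.
        world.awaitAsyncRequest(requestId);
    }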
- std::shared_ptr<MPIMessage> ourMsg; + MpiMessage ourMsg; auto msgIt = umb->getFirstNullMsg(); if (isLocal) { // First receive messages that happened before us diff --git a/src/mpi/mpi.proto b/src/mpi/mpi.proto deleted file mode 100644 index 80a690820..000000000 --- a/src/mpi/mpi.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; - -package faabric.mpi; - -message MPIMessage { - enum MPIMessageType { - NORMAL = 0; - BARRIER_JOIN = 1; - BARRIER_DONE = 2; - SCATTER = 3; - GATHER = 4; - ALLGATHER = 5; - REDUCE = 6; - SCAN = 7; - ALLREDUCE = 8; - ALLTOALL = 9; - SENDRECV = 10; - BROADCAST = 11; - }; - - MPIMessageType messageType = 1; - - int32 id = 2; - int32 worldId = 3; - int32 sender = 4; - int32 destination = 5; - int32 type = 6; - int32 count = 7; - - // For remote messaging - optional bytes buffer = 8; - - // For local messaging - optional int64 bufferPtr = 9; -} diff --git a/tests/dist/mpi/mpi_native.cpp b/tests/dist/mpi/mpi_native.cpp index d41235940..a499fb357 100644 --- a/tests/dist/mpi/mpi_native.cpp +++ b/tests/dist/mpi/mpi_native.cpp @@ -2,9 +2,9 @@ #include <faabric/executor/ExecutorContext.h> #include <faabric/mpi/MpiContext.h> +#include <faabric/mpi/MpiMessage.h> #include <faabric/mpi/MpiWorld.h> #include <faabric/mpi/mpi.h> -#include <faabric/mpi/mpi.pb.h> #include <faabric/scheduler/FunctionCallClient.h> #include <faabric/scheduler/Scheduler.h> #include <faabric/snapshot/SnapshotClient.h> @@ -126,7 +126,7 @@ int MPI_Send(const void* buf, (uint8_t*)buf, datatype, count, - MPIMessage::NORMAL); + MpiMessageType::NORMAL); return MPI_SUCCESS; } @@ -159,7 +159,7 @@ int MPI_Recv(void* buf, datatype, count, status, - MPIMessage::NORMAL); + MpiMessageType::NORMAL); return MPI_SUCCESS; } @@ -245,7 +245,7 @@ int MPI_Bcast(void* buffer, int rank = executingContext.getRank(); world.broadcast( - root, rank, (uint8_t*)buffer, datatype, count, MPIMessage::BROADCAST); + root, rank, (uint8_t*)buffer, datatype, count, MpiMessageType::BROADCAST); return MPI_SUCCESS; } diff --git a/tests/test/mpi/test_mpi_message.cpp b/tests/test/mpi/test_mpi_message.cpp new file mode 100644 index 000000000..9c79f8d3b --- /dev/null +++ b/tests/test/mpi/test_mpi_message.cpp @@ -0,0 +1,123 @@ +#include <catch2/catch.hpp> + +#include <faabric/mpi/MpiMessage.h> +#include <faabric/util/memory.h> + +#include <cstring> + +using namespace faabric::mpi; + +namespace tests { + +bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB) +{ + auto sizeA = msgSize(msgA); + auto sizeB = msgSize(msgB); + + if (sizeA != sizeB) { + return false; + } + + // First, compare the message body (excluding the pointer, which we + // know is at the end) + if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) { + return false; + } + + // Check that if one buffer points to null, so must do the other + if (msgA.buffer == nullptr || msgB.buffer == nullptr) { + return msgA.buffer == msgB.buffer; + } + + // If none points to null, they must point to the same data + auto payloadSizeA = payloadSize(msgA); + auto payloadSizeB = payloadSize(msgB); + // Assert, as this should pass given the previous comparisons + assert(payloadSizeA == payloadSizeB); + + return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0; +} + +TEST_CASE("Test getting a message size", "[mpi]") +{ + MpiMessage msg = { .id = 1, + .worldId = 3, + .sendRank = 3, + .recvRank = 7, + .typeSize = 1, + .count = 3, + .messageType = MpiMessageType::NORMAL }; + + size_t expectedMsgSize = 0; + size_t expectedPayloadSize = 0; + + SECTION("Empty 
message") + { + msg.buffer = nullptr; + msg.count = 0; + expectedMsgSize = sizeof(MpiMessage); + expectedPayloadSize = 0; + } + + SECTION("Non-empty message") + { + std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.count = nums.size(); + msg.typeSize = sizeof(int); + msg.buffer = faabric::util::malloc(msg.count * msg.typeSize); + std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int)); + + expectedPayloadSize = sizeof(int) * nums.size(); + expectedMsgSize = sizeof(MpiMessage) + expectedPayloadSize; + } + + REQUIRE(expectedMsgSize == msgSize(msg)); + REQUIRE(expectedPayloadSize == payloadSize(msg)); + + if (msg.buffer != nullptr) { + faabric::util::free(msg.buffer); + } +} + +TEST_CASE("Test (de)serialising an MPI message", "[mpi]") +{ + MpiMessage msg = { .id = 1, + .worldId = 3, + .sendRank = 3, + .recvRank = 7, + .typeSize = 1, + .count = 3, + .messageType = MpiMessageType::NORMAL }; + + SECTION("Empty message") + { + msg.count = 0; + msg.buffer = nullptr; + } + + SECTION("Non-empty message") + { + std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.count = nums.size(); + msg.typeSize = sizeof(int); + msg.buffer = faabric::util::malloc(msg.count * msg.typeSize); + std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int)); + } + + // Serialise and de-serialise + std::vector<uint8_t> buffer(msgSize(msg)); + serializeMpiMsg(buffer, msg); + + MpiMessage parsedMsg; + parseMpiMsg(buffer, &parsedMsg); + + REQUIRE(areMpiMsgEqual(msg, parsedMsg)); + + if (msg.buffer != nullptr) { + faabric::util::free(msg.buffer); + } + if (parsedMsg.buffer != nullptr) { + faabric::util::free(parsedMsg.buffer); + } +} +} diff --git a/tests/test/mpi/test_mpi_message_buffer.cpp b/tests/test/mpi/test_mpi_message_buffer.cpp index 1674172fd..710a3c259 100644 --- a/tests/test/mpi/test_mpi_message_buffer.cpp +++ b/tests/test/mpi/test_mpi_message_buffer.cpp @@ -21,7 +21,7 @@ MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments( pendingMsg.requestId = requestId; if (!nullMsg) { - pendingMsg.msg = std::make_shared<MPIMessage>(); + pendingMsg.msg = std::make_shared<MpiMessage>(); } return pendingMsg; diff --git a/tests/test/mpi/test_mpi_world.cpp b/tests/test/mpi/test_mpi_world.cpp index 2c0030b5f..8094cae7e 100644 --- a/tests/test/mpi/test_mpi_world.cpp +++ b/tests/test/mpi/test_mpi_world.cpp @@ -212,23 +212,22 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test local barrier", "[mpi]") world.destroy(); } -void checkMessage(MPIMessage& actualMessage, +void checkMessage(MpiMessage& actualMessage, int worldId, int senderRank, int destRank, const std::vector<int>& data) { // Check the message contents - REQUIRE(actualMessage.worldid() == worldId); - REQUIRE(actualMessage.count() == data.size()); - REQUIRE(actualMessage.destination() == destRank); - REQUIRE(actualMessage.sender() == senderRank); - REQUIRE(actualMessage.type() == FAABRIC_INT); + REQUIRE(actualMessage.worldId == worldId); + REQUIRE(actualMessage.count == data.size()); + REQUIRE(actualMessage.recvRank == destRank); + REQUIRE(actualMessage.sendRank == senderRank); + REQUIRE(actualMessage.typeSize == FAABRIC_INT); // Check data - const auto* rawInts = - reinterpret_cast<const int*>(actualMessage.buffer().c_str()); - size_t nInts = actualMessage.buffer().size() / sizeof(int); + const auto* rawInts = reinterpret_cast<const int*>(actualMessage.buffer); + size_t nInts = payloadSize(actualMessage) / sizeof(int); std::vector<int> actualData(rawInts, rawInts + nInts); REQUIRE(actualData == data); } @@ -396,10 +395,10 @@ 
TEST_CASE_METHOD(MpiTestFixture, "Test send/recv message with no data", "[mpi]") SECTION("Check on queue") { // Check message content - MPIMessage actualMessage = - *(world.getLocalQueue(rankA1, rankA2)->dequeue()); - REQUIRE(actualMessage.count() == 0); - REQUIRE(actualMessage.type() == FAABRIC_INT); + MpiMessage actualMessage = + world.getLocalQueue(rankA1, rankA2)->dequeue(); + REQUIRE(actualMessage.count == 0); + REQUIRE(actualMessage.typeSize == FAABRIC_INT); } SECTION("Check receiving with null ptr") @@ -502,7 +501,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]") BYTES(messageData.data()), MPI_INT, messageData.size(), - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); // Recv on all non-root ranks for (int rank = 0; rank < worldSize; rank++) { @@ -515,7 +514,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]") BYTES(actual.data()), MPI_INT, 3, - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); REQUIRE(actual == messageData); } } diff --git a/tests/test/mpi/test_multiple_mpi_worlds.cpp b/tests/test/mpi/test_multiple_mpi_worlds.cpp index 735556f6e..a6e74f0b5 100644 --- a/tests/test/mpi/test_multiple_mpi_worlds.cpp +++ b/tests/test/mpi/test_multiple_mpi_worlds.cpp @@ -164,7 +164,7 @@ TEST_CASE_METHOD(MultiWorldMpiTestFixture, REQUIRE(worldA.getLocalQueueSize(rankA2, 0) == 0); const std::shared_ptr<InMemoryMpiQueue>& queueA2 = worldA.getLocalQueue(rankA1, rankA2); - MPIMessage actualMessage = *(queueA2->dequeue()); + MpiMessage actualMessage = queueA2->dequeue(); // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData); // Check for world B @@ -174,7 +174,7 @@ TEST_CASE_METHOD(MultiWorldMpiTestFixture, REQUIRE(worldB.getLocalQueueSize(rankA2, 0) == 0); const std::shared_ptr<InMemoryMpiQueue>& queueA2B = worldB.getLocalQueue(rankA1, rankA2); - actualMessage = *(queueA2B->dequeue()); + actualMessage = queueA2B->dequeue(); // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData); } diff --git a/tests/test/mpi/test_remote_mpi_worlds.cpp b/tests/test/mpi/test_remote_mpi_worlds.cpp index 1e56b48b1..54662929f 100644 --- a/tests/test/mpi/test_remote_mpi_worlds.cpp +++ b/tests/test/mpi/test_remote_mpi_worlds.cpp @@ -21,12 +21,11 @@ using namespace faabric::mpi; using namespace faabric::scheduler; namespace tests { -std::set<int> getReceiversFromMessages( - std::vector<std::shared_ptr<MPIMessage>> msgs) +std::set<int> getReceiversFromMessages(std::vector<MpiMessage> msgs) { std::set<int> receivers; for (const auto& msg : msgs) { - receivers.insert(msg->destination()); + receivers.insert(msg.recvRank); } return receivers; @@ -108,14 +107,14 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, BYTES(messageData.data()), MPI_INT, messageData.size(), - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); } else { otherWorld.broadcast(sendRank, recvRank, BYTES(messageData.data()), MPI_INT, messageData.size(), - MPIMessage::BROADCAST); + MpiMessageType::BROADCAST); } auto msgs = getMpiMockedMessages(recvRank); REQUIRE(msgs.size() == expectedNumMsg); @@ -219,12 +218,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, thisWorld.destroy(); } -std::set<int> getMsgCountsFromMessages( - std::vector<std::shared_ptr<MPIMessage>> msgs) +std::set<int> getMsgCountsFromMessages(std::vector<MpiMessage> msgs) { std::set<int> counts; for (const auto& msg : msgs) { - counts.insert(msg->count()); + counts.insert(msg.count); } return counts; From 86a8e7647807c062bd7170440aad6746151f92a1 Mon Sep 17 00:00:00 2001 From: Carlos Segarra 
<carlos@carlossegarra.com> Date: Wed, 20 Mar 2024 16:23:49 +0000 Subject: [PATCH 2/2] mpi: make struct 8-byte aligned --- include/faabric/mpi/MpiMessage.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h index 7c85fde48..853f3dce9 100644 --- a/include/faabric/mpi/MpiMessage.h +++ b/include/faabric/mpi/MpiMessage.h @@ -21,6 +21,18 @@ enum MpiMessageType : int32_t BROADCAST = 11, }; +/* Simple fixed-size C-struct to capture the state of an MPI message moving + * through Faabric. + * + * We require fixed-size, and no unique pointers to be able to use + * high-throughput in-memory ring-buffers to send the messages around. + * This also means that we manually malloc/free the data pointer. The message + * size is: + * 7 * int32_t = 7 * 4 bytes = 28 bytes + * 1 * int32_t (padding) = 4 bytes + * 1 * void* = 1 * 8 bytes = 8 bytes + * total = 40 bytes = 5 * 8 so the struct is 8 byte-aligned + */ struct MpiMessage { int32_t id; @@ -30,8 +42,10 @@ struct MpiMessage int32_t typeSize; int32_t count; MpiMessageType messageType; + int32_t __make_8_byte_aligned; void* buffer; }; +static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!"); inline size_t payloadSize(const MpiMessage& msg) {
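As a sanity check on the size arithmetic in the comment above, a standalone snippet along these lines could assert the layout at compile time. It is illustrative only (not part of the patch) and assumes the usual 64-bit ABI in which void* is 8 bytes, as the comment itself does.

    #include <cstddef> // offsetof
    #include <faabric/mpi/MpiMessage.h>

    using faabric::mpi::MpiMessage;

    // 7 * int32_t (28 bytes) plus the explicit int32_t padding (4 bytes) puts
    // the pointer at offset 32, a multiple of 8, so the compiler inserts no
    // implicit padding before it.
    static_assert(offsetof(MpiMessage, buffer) % alignof(void*) == 0,
                  "buffer must start on a pointer-aligned boundary");

    // With an 8-byte void*, the total is 40 bytes, matching the comment above,
    // with no trailing padding either.
    static_assert(sizeof(MpiMessage) ==
                    offsetof(MpiMessage, buffer) + sizeof(void*),
                  "unexpected trailing padding in MpiMessage");

    // Note: call sites that use designated initialisers and omit
    // __make_8_byte_aligned (as the ones in this patch do) value-initialise it
    // to zero, so the padding never feeds uninitialised bytes into
    // serializeMpiMsg's memcpy.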