Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpi: move MpiMessage from protobuf to c-struct #379

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions include/faabric/mpi/MpiMessage.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#pragma once

#include <cstdint>
#include <vector>

namespace faabric::mpi {

enum MpiMessageType : int32_t
{
NORMAL = 0,
BARRIER_JOIN = 1,
BARRIER_DONE = 2,
SCATTER = 3,
GATHER = 4,
ALLGATHER = 5,
REDUCE = 6,
SCAN = 7,
ALLREDUCE = 8,
ALLTOALL = 9,
SENDRECV = 10,
BROADCAST = 11,
};

/* Simple fixed-size C-struct to capture the state of an MPI message moving
* through Faabric.
*
* We require fixed-size, and no unique pointers to be able to use
* high-throughput in-memory ring-buffers to send the messages around.
* This also means that we manually malloc/free the data pointer. The message
* size is:
* 7 * int32_t = 7 * 4 bytes = 28 bytes
* 1 * int32_t (padding) = 4 bytes
* 1 * void* = 1 * 8 bytes = 8 bytes
* total = 40 bytes = 5 * 8 so the struct is 8 byte-aligned
*/
struct MpiMessage
{
int32_t id;
int32_t worldId;
int32_t sendRank;
int32_t recvRank;
int32_t typeSize;
int32_t count;
MpiMessageType messageType;
int32_t __make_8_byte_aligned;
void* buffer;
};
static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!");

inline size_t payloadSize(const MpiMessage& msg)
{
return msg.typeSize * msg.count;
}

inline size_t msgSize(const MpiMessage& msg)
{
return sizeof(MpiMessage) + payloadSize(msg);
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);

void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg);
}
12 changes: 8 additions & 4 deletions include/faabric/mpi/MpiMessageBuffer.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include <faabric/mpi/MpiMessage.h>
#include <faabric/mpi/mpi.h>
#include <faabric/mpi/mpi.pb.h>

#include <iterator>
#include <list>
#include <memory>

namespace faabric::mpi {
/* The MPI message buffer (MMB) keeps track of the asyncrhonous
Expand All @@ -25,17 +26,20 @@ class MpiMessageBuffer
{
public:
int requestId = -1;
std::shared_ptr<MPIMessage> msg = nullptr;
std::shared_ptr<MpiMessage> msg = nullptr;
int sendRank = -1;
int recvRank = -1;
uint8_t* buffer = nullptr;
faabric_datatype_t* dataType = nullptr;
int count = -1;
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;
MpiMessageType messageType = MpiMessageType::NORMAL;

bool isAcknowledged() { return msg != nullptr; }

void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
void acknowledge(const MpiMessage& msgIn)
{
msg = std::make_shared<MpiMessage>(msgIn);
}
};

/* Interface to query the buffer size */
Expand Down
40 changes: 23 additions & 17 deletions include/faabric/mpi/MpiWorld.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <faabric/mpi/MpiMessage.h>
#include <faabric/mpi/MpiMessageBuffer.h>
#include <faabric/mpi/mpi.h>
#include <faabric/mpi/mpi.pb.h>
#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/InMemoryMessageQueue.h>
#include <faabric/transport/PointToPointBroker.h>
Expand All @@ -26,10 +26,9 @@ namespace faabric::mpi {
// -----------------------------------
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
// as the broker already has mocking capabilities
std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
std::vector<MpiMessage> getMpiMockedMessages(int sendRank);

typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
InMemoryMpiQueue;
typedef faabric::util::FixedCapacityQueue<MpiMessage> InMemoryMpiQueue;

class MpiWorld
{
Expand Down Expand Up @@ -73,36 +72,36 @@ class MpiWorld
const uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

int isend(int sendRank,
int recvRank,
const uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

void broadcast(int rootRank,
int thisRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

void recv(int sendRank,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

int irecv(int sendRank,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

void awaitAsyncRequest(int requestId);

Expand Down Expand Up @@ -240,29 +239,36 @@ class MpiWorld
void sendRemoteMpiMessage(std::string dstHost,
int sendRank,
int recvRank,
const std::shared_ptr<MPIMessage>& msg);
const MpiMessage& msg);

std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
int recvRank);
MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank);

// Support for asyncrhonous communications
std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
int recvRank);

std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
int recvRank,
int batchSize = 0);
MpiMessage recvBatchReturnLast(int sendRank,
int recvRank,
int batchSize = 0);

/* Helper methods */

void checkRanksRange(int sendRank, int recvRank);

// Abstraction of the bulk of the recv work, shared among various functions
void doRecv(std::shared_ptr<MPIMessage>& m,
void doRecv(const MpiMessage& m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

// Abstraction of the bulk of the recv work, shared among various functions
void doRecv(std::unique_ptr<MpiMessage> m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MpiMessageType messageType = MpiMessageType::NORMAL);
};
}
22 changes: 1 addition & 21 deletions src/mpi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,12 @@ endif()
# -----------------------------------------------

if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi"))
# Generate protobuf headers
set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h")

protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto)

# Copy the generated headers into place
add_custom_command(
OUTPUT "${MPI_PB_HEADER_COPIED}"
DEPENDS "${MPI_PB_HEADER}"
COMMAND ${CMAKE_COMMAND}
ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/
)

add_custom_target(
mpi_pbh_copied
DEPENDS ${MPI_PB_HEADER_COPIED}
)

add_dependencies(faabric_common_dependencies mpi_pbh_copied)

faabric_lib(mpi
MpiContext.cpp
MpiMessage.cpp
MpiMessageBuffer.cpp
MpiWorld.cpp
MpiWorldRegistry.cpp
${MPI_PB_SRC}
)

target_link_libraries(mpi PRIVATE
Expand Down
36 changes: 36 additions & 0 deletions src/mpi/MpiMessage.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <faabric/mpi/MpiMessage.h>
#include <faabric/util/memory.h>

#include <cassert>
#include <cstdint>
#include <cstring>

namespace faabric::mpi {

void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
{
assert(msg != nullptr);
assert(bytes.size() >= sizeof(MpiMessage));
std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
assert(thisPayloadSize == payloadSize(*msg));

if (thisPayloadSize == 0) {
msg->buffer = nullptr;
return;
}

msg->buffer = faabric::util::malloc(thisPayloadSize);
std::memcpy(
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
{
std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
size_t payloadSz = payloadSize(msg);
if (payloadSz > 0 && msg.buffer != nullptr) {
std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
}
}
}
Loading
Loading