Skip to content

Commit

Permalink
mpi: #385 and #379
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz committed Feb 28, 2024
1 parent 0581ef2 commit b5b8123
Show file tree
Hide file tree
Showing 15 changed files with 378 additions and 243 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ jobs:
if: github.event.pull_request.draft == false
needs: [conan-cache]
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -139,6 +140,7 @@ jobs:
if: github.event.pull_request.draft == false
needs: [conan-cache]
runs-on: ubuntu-latest
timeout-minutes: 20
env:
CONAN_CACHE_MOUNT_SOURCE: ~/.conan/
steps:
Expand Down
49 changes: 49 additions & 0 deletions include/faabric/mpi/MpiMessage.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include <cstdint>
#include <vector>

namespace faabric::mpi {

enum MpiMessageType : int32_t
{
NORMAL = 0,
BARRIER_JOIN = 1,
BARRIER_DONE = 2,
SCATTER = 3,
GATHER = 4,
ALLGATHER = 5,
REDUCE = 6,
SCAN = 7,
ALLREDUCE = 8,
ALLTOALL = 9,
SENDRECV = 10,
BROADCAST = 11,
};

struct MpiMessage
{
int32_t id;
int32_t worldId;
int32_t sendRank;
int32_t recvRank;
int32_t typeSize;
int32_t count;
MpiMessageType messageType;
void* buffer;
};

inline size_t payloadSize(const MpiMessage& msg)
{
return msg.typeSize * msg.count;
}

inline size_t msgSize(const MpiMessage& msg)
{
return sizeof(MpiMessage) + payloadSize(msg);
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);

void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg);
}
12 changes: 8 additions & 4 deletions include/faabric/mpi/MpiMessageBuffer.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include <faabric/mpi/MpiMessage.h>
#include <faabric/mpi/mpi.h>
#include <faabric/mpi/mpi.pb.h>

#include <iterator>
#include <list>
#include <memory>

namespace faabric::mpi {
/* The MPI message buffer (MMB) keeps track of the asyncrhonous
Expand All @@ -25,17 +26,20 @@ class MpiMessageBuffer
{
public:
int requestId = -1;
std::shared_ptr<MPIMessage> msg = nullptr;
std::shared_ptr<MpiMessage> msg = nullptr;
int sendRank = -1;
int recvRank = -1;
uint8_t* buffer = nullptr;
faabric_datatype_t* dataType = nullptr;
int count = -1;
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;
MpiMessageType messageType = MpiMessageType::NORMAL;

bool isAcknowledged() { return msg != nullptr; }

void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
void acknowledge(const MpiMessage& msgIn)
{
msg = std::make_shared<MpiMessage>(msgIn);
}
};

/* Interface to query the buffer size */
Expand Down
42 changes: 23 additions & 19 deletions include/faabric/mpi/MpiWorld.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <faabric/mpi/MpiMessage.h>
#include <faabric/mpi/MpiMessageBuffer.h>
#include <faabric/mpi/mpi.h>
#include <faabric/mpi/mpi.pb.h>
#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/InMemoryMessageQueue.h>
#include <faabric/transport/PointToPointBroker.h>
Expand All @@ -26,10 +26,9 @@ namespace faabric::mpi {
// -----------------------------------
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
// as the broker already has mocking capabilities
std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
std::vector<MpiMessage> getMpiMockedMessages(int sendRank);

typedef faabric::util::SpinLockQueue<std::shared_ptr<MPIMessage>>
InMemoryMpiQueue;
typedef faabric::util::SpinLockQueue<MpiMessage> InMemoryMpiQueue;

class MpiWorld
{
Expand Down Expand Up @@ -73,36 +72,36 @@ class MpiWorld
const uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

int isend(int sendRank,
int recvRank,
const uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

void broadcast(int rootRank,
int thisRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

void recv(int sendRank,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

int irecv(int sendRank,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);

void awaitAsyncRequest(int requestId);

Expand Down Expand Up @@ -185,8 +184,6 @@ class MpiWorld

std::shared_ptr<InMemoryMpiQueue> getLocalQueue(int sendRank, int recvRank);

long getLocalQueueSize(int sendRank, int recvRank);

void overrideHost(const std::string& newHost);

double getWTime();
Expand Down Expand Up @@ -240,29 +237,36 @@ class MpiWorld
void sendRemoteMpiMessage(std::string dstHost,
int sendRank,
int recvRank,
const std::shared_ptr<MPIMessage>& msg);
const MpiMessage& msg);

std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
int recvRank);
MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank);

// Support for asyncrhonous communications
std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
int recvRank);

std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
int recvRank,
int batchSize = 0);
MpiMessage recvBatchReturnLast(int sendRank,
int recvRank,
int batchSize = 0);

/* Helper methods */

void checkRanksRange(int sendRank, int recvRank);

// Abstraction of the bulk of the recv work, shared among various functions
void doRecv(std::shared_ptr<MPIMessage>& m,
void doRecv(const MpiMessage& m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MpiMessageType messageType = MpiMessageType::NORMAL);

// Abstraction of the bulk of the recv work, shared among various functions
void doRecv(std::unique_ptr<MpiMessage> m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
MpiMessageType messageType = MpiMessageType::NORMAL);
};
}
24 changes: 17 additions & 7 deletions include/faabric/util/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,34 @@ template<typename T>
class SpinLockQueue
{
public:
void enqueue(T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) {
while (!mq.push(value)) { ; };
void enqueue(T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
{
while (!mq.push(value)) {
;
};
}

T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) {
T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
{
T value;

while (!mq.pop(value)) { ; }
while (!mq.pop(value)) {
;
}

return value;
}

long size() {
long size()
{
throw std::runtime_error("Size for fast queue unimplemented!");
}

void drain() {
while (mq.pop()) { ; }
void drain()
{
while (mq.pop()) {
;
}
}

void reset() { ; }
Expand Down
22 changes: 1 addition & 21 deletions src/mpi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,12 @@ endif()
# -----------------------------------------------

if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi"))
# Generate protobuf headers
set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h")

protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto)

# Copy the generated headers into place
add_custom_command(
OUTPUT "${MPI_PB_HEADER_COPIED}"
DEPENDS "${MPI_PB_HEADER}"
COMMAND ${CMAKE_COMMAND}
ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/
)

add_custom_target(
mpi_pbh_copied
DEPENDS ${MPI_PB_HEADER_COPIED}
)

add_dependencies(faabric_common_dependencies mpi_pbh_copied)

faabric_lib(mpi
MpiContext.cpp
MpiMessage.cpp
MpiMessageBuffer.cpp
MpiWorld.cpp
MpiWorldRegistry.cpp
${MPI_PB_SRC}
)

target_link_libraries(mpi PRIVATE
Expand Down
36 changes: 36 additions & 0 deletions src/mpi/MpiMessage.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <faabric/mpi/MpiMessage.h>
#include <faabric/util/memory.h>

#include <cassert>
#include <cstdint>
#include <cstring>

namespace faabric::mpi {

void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
{
assert(msg != nullptr);
assert(bytes.size() >= sizeof(MpiMessage));
std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
assert(thisPayloadSize == payloadSize(*msg));

if (thisPayloadSize == 0) {
msg->buffer = nullptr;
return;
}

msg->buffer = faabric::util::malloc(thisPayloadSize);
std::memcpy(
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
{
std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
size_t payloadSz = payloadSize(msg);
if (payloadSz > 0 && msg.buffer != nullptr) {
std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
}
}
}
Loading

0 comments on commit b5b8123

Please sign in to comment.