From 26568f78c368256eae5aa3cb74c94bc737bbd85a Mon Sep 17 00:00:00 2001
From: Carlos Segarra <carlos@carlossegarra.com>
Date: Thu, 22 Feb 2024 17:51:27 +0000
Subject: [PATCH 1/2] mpi: move MpiMessage from protobuf to struct

---
 include/faabric/mpi/MpiMessage.h            |  49 +++++
 include/faabric/mpi/MpiMessageBuffer.h      |  12 +-
 include/faabric/mpi/MpiWorld.h              |  40 ++--
 src/mpi/CMakeLists.txt                      |  22 +--
 src/mpi/MpiMessage.cpp                      |  36 ++++
 src/mpi/MpiWorld.cpp                        | 197 ++++++++++----------
 src/mpi/mpi.proto                           |  35 ----
 tests/dist/mpi/mpi_native.cpp               |   8 +-
 tests/test/mpi/test_mpi_message.cpp         | 123 ++++++++++++
 tests/test/mpi/test_mpi_message_buffer.cpp  |   2 +-
 tests/test/mpi/test_mpi_world.cpp           |  29 ++-
 tests/test/mpi/test_multiple_mpi_worlds.cpp |   4 +-
 tests/test/mpi/test_remote_mpi_worlds.cpp   |  14 +-
 13 files changed, 364 insertions(+), 207 deletions(-)
 create mode 100644 include/faabric/mpi/MpiMessage.h
 create mode 100644 src/mpi/MpiMessage.cpp
 delete mode 100644 src/mpi/mpi.proto
 create mode 100644 tests/test/mpi/test_mpi_message.cpp

diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h
new file mode 100644
index 000000000..7c85fde48
--- /dev/null
+++ b/include/faabric/mpi/MpiMessage.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+namespace faabric::mpi {
+
+enum MpiMessageType : int32_t
+{
+    NORMAL = 0,
+    BARRIER_JOIN = 1,
+    BARRIER_DONE = 2,
+    SCATTER = 3,
+    GATHER = 4,
+    ALLGATHER = 5,
+    REDUCE = 6,
+    SCAN = 7,
+    ALLREDUCE = 8,
+    ALLTOALL = 9,
+    SENDRECV = 10,
+    BROADCAST = 11,
+};
+
+struct MpiMessage
+{
+    int32_t id;
+    int32_t worldId;
+    int32_t sendRank;
+    int32_t recvRank;
+    int32_t typeSize;
+    int32_t count;
+    MpiMessageType messageType;
+    void* buffer;
+};
+
+inline size_t payloadSize(const MpiMessage& msg)
+{
+    return msg.typeSize * msg.count;
+}
+
+inline size_t msgSize(const MpiMessage& msg)
+{
+    return sizeof(MpiMessage) + payloadSize(msg);
+}
+
+void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);
+
+void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg);
+}
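
For reference, a minimal usage sketch of the new struct-based API (not part of the patch; it only relies on the helpers declared above plus faabric::util::malloc/free from <faabric/util/memory.h>, and the function name roundTripExample is illustrative). A message is built with a designated initialiser, serialised into a caller-allocated byte vector sized with msgSize(), and parsed back; parseMpiMsg mallocs a fresh payload buffer that the receiver must free:

    #include <faabric/mpi/MpiMessage.h>
    #include <faabric/util/memory.h>

    #include <cstring>
    #include <vector>

    using namespace faabric::mpi;

    void roundTripExample()
    {
        // Payload: three ints owned by a manually malloc-ed buffer
        std::vector<int> nums = { 1, 2, 3 };
        MpiMessage msg = { .id = 1,
                           .worldId = 5,
                           .sendRank = 0,
                           .recvRank = 1,
                           .typeSize = (int32_t)sizeof(int),
                           .count = (int32_t)nums.size(),
                           .messageType = MpiMessageType::NORMAL,
                           .buffer =
                             faabric::util::malloc(nums.size() * sizeof(int)) };
        std::memcpy(msg.buffer, nums.data(), payloadSize(msg));

        // Wire format: the raw struct bytes followed by payloadSize(msg) bytes
        std::vector<uint8_t> wire(msgSize(msg));
        serializeMpiMsg(wire, msg);

        // parseMpiMsg copies the header and mallocs a new payload buffer,
        // which the receiving side is responsible for freeing
        MpiMessage parsed;
        parseMpiMsg(wire, &parsed);

        faabric::util::free(msg.buffer);
        faabric::util::free(parsed.buffer);
    }
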
diff --git a/include/faabric/mpi/MpiMessageBuffer.h b/include/faabric/mpi/MpiMessageBuffer.h
index 9fc67b644..c36f89887 100644
--- a/include/faabric/mpi/MpiMessageBuffer.h
+++ b/include/faabric/mpi/MpiMessageBuffer.h
@@ -1,8 +1,9 @@
+#include <faabric/mpi/MpiMessage.h>
 #include <faabric/mpi/mpi.h>
-#include <faabric/mpi/mpi.pb.h>
 
 #include <iterator>
 #include <list>
+#include <memory>
 
 namespace faabric::mpi {
 /* The MPI message buffer (MMB) keeps track of the asynchronous
@@ -25,17 +26,20 @@ class MpiMessageBuffer
     {
       public:
         int requestId = -1;
-        std::shared_ptr<MPIMessage> msg = nullptr;
+        std::shared_ptr<MpiMessage> msg = nullptr;
         int sendRank = -1;
         int recvRank = -1;
         uint8_t* buffer = nullptr;
         faabric_datatype_t* dataType = nullptr;
         int count = -1;
-        MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;
+        MpiMessageType messageType = MpiMessageType::NORMAL;
 
         bool isAcknowledged() { return msg != nullptr; }
 
-        void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
+        void acknowledge(const MpiMessage& msgIn)
+        {
+            msg = std::make_shared<MpiMessage>(msgIn);
+        }
     };
 
     /* Interface to query the buffer size */
diff --git a/include/faabric/mpi/MpiWorld.h b/include/faabric/mpi/MpiWorld.h
index adee54137..8f9cb918c 100644
--- a/include/faabric/mpi/MpiWorld.h
+++ b/include/faabric/mpi/MpiWorld.h
@@ -1,8 +1,8 @@
 #pragma once
 
+#include <faabric/mpi/MpiMessage.h>
 #include <faabric/mpi/MpiMessageBuffer.h>
 #include <faabric/mpi/mpi.h>
-#include <faabric/mpi/mpi.pb.h>
 #include <faabric/proto/faabric.pb.h>
 #include <faabric/scheduler/InMemoryMessageQueue.h>
 #include <faabric/transport/PointToPointBroker.h>
@@ -26,10 +26,9 @@ namespace faabric::mpi {
 // -----------------------------------
 // MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
 // as the broker already has mocking capabilities
-std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
+std::vector<MpiMessage> getMpiMockedMessages(int sendRank);
 
-typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
-  InMemoryMpiQueue;
+typedef faabric::util::FixedCapacityQueue<MpiMessage> InMemoryMpiQueue;
 
 class MpiWorld
 {
@@ -73,21 +72,21 @@ class MpiWorld
               const uint8_t* buffer,
               faabric_datatype_t* dataType,
               int count,
-              MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
+              MpiMessageType messageType = MpiMessageType::NORMAL);
 
     int isend(int sendRank,
               int recvRank,
               const uint8_t* buffer,
               faabric_datatype_t* dataType,
               int count,
-              MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
+              MpiMessageType messageType = MpiMessageType::NORMAL);
 
     void broadcast(int rootRank,
                    int thisRank,
                    uint8_t* buffer,
                    faabric_datatype_t* dataType,
                    int count,
-                   MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
+                   MpiMessageType messageType = MpiMessageType::NORMAL);
 
     void recv(int sendRank,
               int recvRank,
@@ -95,14 +94,14 @@ class MpiWorld
               faabric_datatype_t* dataType,
               int count,
               MPI_Status* status,
-              MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
+              MpiMessageType messageType = MpiMessageType::NORMAL);
 
     int irecv(int sendRank,
               int recvRank,
               uint8_t* buffer,
               faabric_datatype_t* dataType,
               int count,
-              MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
+              MpiMessageType messageType = MpiMessageType::NORMAL);
 
     void awaitAsyncRequest(int requestId);
 
@@ -240,29 +239,36 @@ class MpiWorld
     void sendRemoteMpiMessage(std::string dstHost,
                               int sendRank,
                               int recvRank,
-                              const std::shared_ptr<MPIMessage>& msg);
+                              const MpiMessage& msg);
 
-    std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
-                                                     int recvRank);
+    MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank);
 
     // Support for asynchronous communications
     std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
                                                               int recvRank);
 
-    std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
-                                                    int recvRank,
-                                                    int batchSize = 0);
+    MpiMessage recvBatchReturnLast(int sendRank,
+                                   int recvRank,
+                                   int batchSize = 0);
 
     /* Helper methods */
 
     void checkRanksRange(int sendRank, int recvRank);
 
     // Abstraction of the bulk of the recv work, shared among various functions
-    void doRecv(std::shared_ptr<MPIMessage>& m,
+    void doRecv(const MpiMessage& m,
                 uint8_t* buffer,
                 faabric_datatype_t* dataType,
                 int count,
                 MPI_Status* status,
-                MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
+                MpiMessageType messageType = MpiMessageType::NORMAL);
+
+    // Abstraction of the bulk of the recv work, shared among various functions
+    void doRecv(std::unique_ptr<MpiMessage> m,
+                uint8_t* buffer,
+                faabric_datatype_t* dataType,
+                int count,
+                MPI_Status* status,
+                MpiMessageType messageType = MpiMessageType::NORMAL);
 };
 }
diff --git a/src/mpi/CMakeLists.txt b/src/mpi/CMakeLists.txt
index dfd434c87..3ac5b98c9 100644
--- a/src/mpi/CMakeLists.txt
+++ b/src/mpi/CMakeLists.txt
@@ -38,32 +38,12 @@ endif()
 # -----------------------------------------------
 
 if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi"))
-    # Generate protobuf headers
-    set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h")
-
-    protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto)
-
-    # Copy the generated headers into place
-    add_custom_command(
-        OUTPUT "${MPI_PB_HEADER_COPIED}"
-        DEPENDS "${MPI_PB_HEADER}"
-        COMMAND ${CMAKE_COMMAND}
-        ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/
-    )
-
-    add_custom_target(
-        mpi_pbh_copied
-        DEPENDS ${MPI_PB_HEADER_COPIED}
-    )
-
-    add_dependencies(faabric_common_dependencies mpi_pbh_copied)
-
     faabric_lib(mpi
         MpiContext.cpp
+        MpiMessage.cpp
         MpiMessageBuffer.cpp
         MpiWorld.cpp
         MpiWorldRegistry.cpp
-        ${MPI_PB_SRC}
     )
 
     target_link_libraries(mpi PRIVATE
diff --git a/src/mpi/MpiMessage.cpp b/src/mpi/MpiMessage.cpp
new file mode 100644
index 000000000..57ee8c85e
--- /dev/null
+++ b/src/mpi/MpiMessage.cpp
@@ -0,0 +1,36 @@
+#include <faabric/mpi/MpiMessage.h>
+#include <faabric/util/memory.h>
+
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+
+namespace faabric::mpi {
+
+void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
+{
+    assert(msg != nullptr);
+    assert(bytes.size() >= sizeof(MpiMessage));
+    std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
+    size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
+    assert(thisPayloadSize == payloadSize(*msg));
+
+    if (thisPayloadSize == 0) {
+        msg->buffer = nullptr;
+        return;
+    }
+
+    msg->buffer = faabric::util::malloc(thisPayloadSize);
+    std::memcpy(
+      msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
+}
+
+void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
+{
+    std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
+    size_t payloadSz = payloadSize(msg);
+    if (payloadSz > 0 && msg.buffer != nullptr) {
+        std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
+    }
+}
+}
diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp
index d50344c40..5b0887c7a 100644
--- a/src/mpi/MpiWorld.cpp
+++ b/src/mpi/MpiWorld.cpp
@@ -1,6 +1,6 @@
 #include <faabric/batch-scheduler/SchedulingDecision.h>
+#include <faabric/mpi/MpiMessage.h>
 #include <faabric/mpi/MpiWorld.h>
-#include <faabric/mpi/mpi.pb.h>
 #include <faabric/planner/PlannerClient.h>
 #include <faabric/transport/macros.h>
 #include <faabric/util/ExecGraph.h>
@@ -34,10 +34,9 @@ static std::mutex mockMutex;
 
 // The identifier in this map is the sending rank. For the receiver's rank
 // we can inspect the MPIMessage object
-static std::map<int, std::vector<std::shared_ptr<MPIMessage>>>
-  mpiMockedMessages;
+static std::map<int, std::vector<MpiMessage>> mpiMockedMessages;
 
-std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank)
+std::vector<MpiMessage> getMpiMockedMessages(int sendRank)
 {
     faabric::util::UniqueLock lock(mockMutex);
     return mpiMockedMessages[sendRank];
@@ -53,12 +52,12 @@ MpiWorld::MpiWorld()
 void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
                                     int sendRank,
                                     int recvRank,
-                                    const std::shared_ptr<MPIMessage>& msg)
+                                    const MpiMessage& msg)
 {
-    std::string serialisedBuffer;
-    if (!msg->SerializeToString(&serialisedBuffer)) {
-        throw std::runtime_error("Error serialising message");
-    }
+    // Serialise
+    std::vector<uint8_t> serialisedBuffer(msgSize(msg));
+    serializeMpiMsg(serialisedBuffer, msg);
+
     try {
         broker.sendMessage(
           thisRankMsg->groupid(),
@@ -79,8 +78,7 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
     }
 }
 
-std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
-                                                           int recvRank)
+MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank)
 {
     std::vector<uint8_t> msg;
     try {
@@ -95,8 +93,12 @@ std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
                      recvRank);
         throw e;
     }
-    PARSE_MSG(MPIMessage, msg.data(), msg.size());
-    return std::make_shared<MPIMessage>(parsedMsg);
+
+    // TODO(mpi-opt): make sure we minimise copies here
+    MpiMessage parsedMsg;
+    parseMpiMsg(msg, &parsedMsg);
+
+    return parsedMsg;
 }
 
 std::shared_ptr<MpiMessageBuffer> MpiWorld::getUnackedMessageBuffer(
@@ -447,7 +449,7 @@ int MpiWorld::isend(int sendRank,
                     const uint8_t* buffer,
                     faabric_datatype_t* dataType,
                     int count,
-                    MPIMessage::MPIMessageType messageType)
+                    MpiMessageType messageType)
 {
     int requestId = (int)faabric::util::generateGid();
     iSendRequests.insert(requestId);
@@ -462,7 +464,7 @@ int MpiWorld::irecv(int sendRank,
                     uint8_t* buffer,
                     faabric_datatype_t* dataType,
                     int count,
-                    MPIMessage::MPIMessageType messageType)
+                    MpiMessageType messageType)
 {
     int requestId = (int)faabric::util::generateGid();
     reqIdToRanks.try_emplace(requestId, sendRank, recvRank);
@@ -489,7 +491,7 @@ void MpiWorld::send(int sendRank,
                     const uint8_t* buffer,
                     faabric_datatype_t* dataType,
                     int count,
-                    MPIMessage::MPIMessageType messageType)
+                    MpiMessageType messageType)
 {
     // Sanity-check input parameters
     checkRanksRange(sendRank, recvRank);
@@ -506,45 +508,39 @@ void MpiWorld::send(int sendRank,
     // Generate a message ID
     int msgId = (localMsgCount + 1) % INT32_MAX;
 
-    // Create the message
-    auto m = std::make_shared<MPIMessage>();
-    m->set_id(msgId);
-    m->set_worldid(id);
-    m->set_sender(sendRank);
-    m->set_destination(recvRank);
-    m->set_type(dataType->id);
-    m->set_count(count);
-    m->set_messagetype(messageType);
-
-    // Set up message data
-    bool mustSendData = count > 0 && buffer != nullptr;
+    MpiMessage msg = { .id = msgId,
+                       .worldId = id,
+                       .sendRank = sendRank,
+                       .recvRank = recvRank,
+                       .typeSize = dataType->size,
+                       .count = count,
+                       .messageType = messageType,
+                       .buffer = (void*)buffer };
 
     // Mock the message sending in tests
     if (faabric::util::isMockMode()) {
-        mpiMockedMessages[sendRank].push_back(m);
+        mpiMockedMessages[sendRank].push_back(msg);
         return;
     }
 
     // Dispatch the message locally or globally
     if (isLocal) {
-        if (mustSendData) {
+        // Take ownership of the buffer data if we are going to move it to
+        // the in-memory queues for local messaging
+        if (count > 0 && buffer != nullptr) {
             void* bufferPtr = faabric::util::malloc(count * dataType->size);
             std::memcpy(bufferPtr, buffer, count * dataType->size);
 
-            m->set_bufferptr((uint64_t)bufferPtr);
+            msg.buffer = bufferPtr;
         }
 
         SPDLOG_TRACE(
           "MPI - send {} -> {} ({})", sendRank, recvRank, messageType);
-        getLocalQueue(sendRank, recvRank)->enqueue(std::move(m));
+        getLocalQueue(sendRank, recvRank)->enqueue(msg);
     } else {
-        if (mustSendData) {
-            m->set_buffer(buffer, dataType->size * count);
-        }
-
         SPDLOG_TRACE(
           "MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType);
-        sendRemoteMpiMessage(otherHost, sendRank, recvRank, m);
+        sendRemoteMpiMessage(otherHost, sendRank, recvRank, msg);
     }
 
     /* 02/05/2022 - The following bit of code fails randomly with a protobuf
@@ -572,7 +568,7 @@ void MpiWorld::recv(int sendRank,
                     faabric_datatype_t* dataType,
                     int count,
                     MPI_Status* status,
-                    MPIMessage::MPIMessageType messageType)
+                    MpiMessageType messageType)
 {
     // Sanity-check input parameters
     checkRanksRange(sendRank, recvRank);
@@ -582,54 +578,47 @@ void MpiWorld::recv(int sendRank,
         return;
     }
 
-    // Recv message from underlying transport
-    std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank);
+    auto msg = recvBatchReturnLast(sendRank, recvRank);
 
-    // Do the processing
-    doRecv(m, buffer, dataType, count, status, messageType);
+    doRecv(std::move(msg), buffer, dataType, count, status, messageType);
 }
 
-void MpiWorld::doRecv(std::shared_ptr<MPIMessage>& m,
+void MpiWorld::doRecv(const MpiMessage& m,
                       uint8_t* buffer,
                       faabric_datatype_t* dataType,
                       int count,
                       MPI_Status* status,
-                      MPIMessage::MPIMessageType messageType)
+                      MpiMessageType messageType)
 {
     // Assert message integrity
     // Note - these checks won't happen in Release builds
-    if (m->messagetype() != messageType) {
+    if (m.messageType != messageType) {
         SPDLOG_ERROR("Different message types (got: {}, expected: {})",
-                     m->messagetype(),
+                     m.messageType,
                      messageType);
     }
-    assert(m->messagetype() == messageType);
-    assert(m->count() <= count);
-
-    const std::string otherHost = getHostForRank(m->destination());
-    bool isLocal =
-      getHostForRank(m->destination()) == getHostForRank(m->sender());
-
-    if (m->count() > 0) {
-        if (isLocal) {
-            // Make sure we do not overflow the recepient buffer
-            auto bytesToCopy = std::min<size_t>(m->count() * dataType->size,
-                                                count * dataType->size);
-            std::memcpy(buffer, (void*)m->bufferptr(), bytesToCopy);
-            faabric::util::free((void*)m->bufferptr());
-        } else {
-            // TODO - avoid copy here
-            std::move(m->buffer().begin(), m->buffer().end(), buffer);
-        }
+    assert(m.messageType == messageType);
+    assert(m.count <= count);
+
+    // We must copy the data into the application-provided buffer
+    if (m.count > 0 && m.buffer != nullptr) {
+        // Make sure we do not overflow the recipient buffer
+        auto bytesToCopy =
+          std::min<size_t>(m.count * dataType->size, count * dataType->size);
+        std::memcpy(buffer, m.buffer, bytesToCopy);
+
+        // This buffer has been malloc-ed either as part of a local `send`
+        // or as part of a remote `parseMpiMsg`
+        faabric::util::free((void*)m.buffer);
     }
 
     // Set status values if required
     if (status != nullptr) {
-        status->MPI_SOURCE = m->sender();
+        status->MPI_SOURCE = m.sendRank;
         status->MPI_ERROR = MPI_SUCCESS;
 
         // Take the message size here as the receive count may be larger
-        status->bytesSize = m->count() * dataType->size;
+        status->bytesSize = m.count * dataType->size;
 
         // TODO - thread through tag
         status->MPI_TAG = -1;
@@ -667,14 +656,14 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
                        recvBuffer,
                        recvDataType,
                        recvCount,
-                       MPIMessage::SENDRECV);
+                       MpiMessageType::SENDRECV);
     // Then send the message
     send(myRank,
          sendRank,
          sendBuffer,
          sendDataType,
          sendCount,
-         MPIMessage::SENDRECV);
+         MpiMessageType::SENDRECV);
     // And wait
     awaitAsyncRequest(recvId);
 }
@@ -684,7 +673,7 @@ void MpiWorld::broadcast(int sendRank,
                          uint8_t* buffer,
                          faabric_datatype_t* dataType,
                          int count,
-                         MPIMessage::MPIMessageType messageType)
+                         MpiMessageType messageType)
 {
     SPDLOG_TRACE("MPI - bcast {} -> {}", sendRank, recvRank);
 
@@ -795,7 +784,7 @@ void MpiWorld::scatter(int sendRank,
                      startPtr,
                      sendType,
                      sendCount,
-                     MPIMessage::SCATTER);
+                     MpiMessageType::SCATTER);
             }
         }
     } else {
@@ -806,7 +795,7 @@ void MpiWorld::scatter(int sendRank,
              recvType,
              recvCount,
              nullptr,
-             MPIMessage::SCATTER);
+             MpiMessageType::SCATTER);
     }
 }
 
@@ -880,7 +869,7 @@ void MpiWorld::gather(int sendRank,
                              recvType,
                              recvCount,
                              nullptr,
-                             MPIMessage::GATHER);
+                             MpiMessageType::GATHER);
                     }
                 }
             } else {
@@ -894,7 +883,7 @@ void MpiWorld::gather(int sendRank,
                      recvType,
                      recvCount * it.second.size(),
                      nullptr,
-                     MPIMessage::GATHER);
+                     MpiMessageType::GATHER);
 
                 // Copy each received chunk to its offset
                 for (int r = 0; r < it.second.size(); r++) {
@@ -924,7 +913,7 @@ void MpiWorld::gather(int sendRank,
                      sendType,
                      sendCount,
                      nullptr,
-                     MPIMessage::GATHER);
+                     MpiMessageType::GATHER);
             }
         }
 
@@ -934,7 +923,7 @@ void MpiWorld::gather(int sendRank,
              rankData.get(),
              sendType,
              sendCount * ranksForHost[thisHost].size(),
-             MPIMessage::GATHER);
+             MpiMessageType::GATHER);
 
     } else if (isLocalLeader && isLocalGather) {
         // Scenario 3
@@ -943,7 +932,7 @@ void MpiWorld::gather(int sendRank,
              sendBuffer + sendBufferOffset,
              sendType,
              sendCount,
-             MPIMessage::GATHER);
+             MpiMessageType::GATHER);
     } else if (!isLocalLeader && !isLocalGather) {
         // Scenario 4
         send(sendRank,
@@ -951,7 +940,7 @@ void MpiWorld::gather(int sendRank,
              sendBuffer + sendBufferOffset,
              sendType,
              sendCount,
-             MPIMessage::GATHER);
+             MpiMessageType::GATHER);
     } else if (!isLocalLeader && isLocalGather) {
         // Scenario 5
         send(sendRank,
@@ -959,7 +948,7 @@ void MpiWorld::gather(int sendRank,
              sendBuffer + sendBufferOffset,
              sendType,
              sendCount,
-             MPIMessage::GATHER);
+             MpiMessageType::GATHER);
     } else {
         SPDLOG_ERROR("Don't know how to gather rank's data.");
         SPDLOG_ERROR("- sendRank: {}\n- recvRank: {}\n- isGatherReceiver: "
@@ -1001,7 +990,7 @@ void MpiWorld::allGather(int rank,
 
     // Do a broadcast with a hard-coded root
     broadcast(
-      root, rank, recvBuffer, recvType, fullCount, MPIMessage::ALLGATHER);
+      root, rank, recvBuffer, recvType, fullCount, MpiMessageType::ALLGATHER);
 }
 
 void MpiWorld::awaitAsyncRequest(int requestId)
@@ -1033,10 +1022,10 @@ void MpiWorld::awaitAsyncRequest(int requestId)
     std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt =
       umb->getRequestPendingMsg(requestId);
 
-    std::shared_ptr<MPIMessage> m;
+    MpiMessage m;
     if (msgIt->msg != nullptr) {
         // This id has already been acknowledged by a recv call, so do the recv
-        m = msgIt->msg;
+        m = *(msgIt->msg);
     } else {
         // We need to acknowledge all messages not acknowledged from the
         // beginning until us
@@ -1094,7 +1083,7 @@ void MpiWorld::reduce(int sendRank,
                          datatype,
                          count,
                          nullptr,
-                         MPIMessage::REDUCE);
+                         MpiMessageType::REDUCE);
 
                     op_reduce(
                       operation, datatype, count, rankData.get(), recvBuffer);
@@ -1108,7 +1097,7 @@ void MpiWorld::reduce(int sendRank,
                      datatype,
                      count,
                      nullptr,
-                     MPIMessage::REDUCE);
+                     MpiMessageType::REDUCE);
 
                 op_reduce(
                   operation, datatype, count, rankData.get(), recvBuffer);
@@ -1138,7 +1127,7 @@ void MpiWorld::reduce(int sendRank,
                      datatype,
                      count,
                      nullptr,
-                     MPIMessage::REDUCE);
+                     MpiMessageType::REDUCE);
 
                 op_reduce(operation,
                           datatype,
@@ -1152,7 +1141,7 @@ void MpiWorld::reduce(int sendRank,
                  sendBufferCopy.get(),
                  datatype,
                  count,
-                 MPIMessage::REDUCE);
+                 MpiMessageType::REDUCE);
         } else {
             // Send to the receiver rank
             send(sendRank,
@@ -1160,7 +1149,7 @@ void MpiWorld::reduce(int sendRank,
                  sendBuffer,
                  datatype,
                  count,
-                 MPIMessage::REDUCE);
+                 MpiMessageType::REDUCE);
         }
     } else {
         // If we are neither the receiver of the reduce nor a local leader, we
@@ -1175,7 +1164,7 @@ void MpiWorld::reduce(int sendRank,
              sendBuffer,
              datatype,
              count,
-             MPIMessage::REDUCE);
+             MpiMessageType::REDUCE);
     }
 }
 
@@ -1191,7 +1180,7 @@ void MpiWorld::allReduce(int rank,
     reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation);
 
     // Second, 0 broadcasts the result to all ranks
-    broadcast(0, rank, recvBuffer, datatype, count, MPIMessage::ALLREDUCE);
+    broadcast(0, rank, recvBuffer, datatype, count, MpiMessageType::ALLREDUCE);
 }
 
 void MpiWorld::op_reduce(faabric_op_t* operation,
@@ -1350,14 +1339,14 @@ void MpiWorld::scan(int rank,
              datatype,
              count,
              nullptr,
-             MPIMessage::SCAN);
+             MpiMessageType::SCAN);
         // Reduce with our own value
         op_reduce(operation, datatype, count, currentAcc.get(), recvBuffer);
     }
 
     // If not the last process, send to the next one
     if (rank < this->size - 1) {
-        send(rank, rank + 1, recvBuffer, MPI_INT, count, MPIMessage::SCAN);
+        send(rank, rank + 1, recvBuffer, MPI_INT, count, MpiMessageType::SCAN);
     }
 }
 
@@ -1385,7 +1374,12 @@ void MpiWorld::allToAll(int rank,
               sendChunk, sendChunk + sendOffset, recvBuffer + rankOffset);
         } else {
             // Send message to other rank
-            send(rank, r, sendChunk, sendType, sendCount, MPIMessage::ALLTOALL);
+            send(rank,
+                 r,
+                 sendChunk,
+                 sendType,
+                 sendCount,
+                 MpiMessageType::ALLTOALL);
         }
     }
 
@@ -1405,7 +1399,7 @@ void MpiWorld::allToAll(int rank,
              recvType,
              recvCount,
              nullptr,
-             MPIMessage::ALLTOALL);
+             MpiMessageType::ALLTOALL);
     }
 }
 
@@ -1416,15 +1410,17 @@ void MpiWorld::allToAll(int rank,
 // queues.
 void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)
 {
+    throw std::runtime_error("Probe not implemented!");
+    /*
     const std::shared_ptr<InMemoryMpiQueue>& queue =
       getLocalQueue(sendRank, recvRank);
-    // 30/12/21 - Peek will throw a runtime error
     std::shared_ptr<MPIMessage> m = *(queue->peek());
 
     faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type());
     status->bytesSize = m->count() * datatype->size;
     status->MPI_ERROR = 0;
     status->MPI_SOURCE = m->sender();
+    */
 }
 
 void MpiWorld::barrier(int thisRank)
@@ -1437,17 +1433,17 @@ void MpiWorld::barrier(int thisRank)
         // Await messages from all others
         for (int r = 1; r < size; r++) {
             MPI_Status s{};
-            recv(r, 0, nullptr, MPI_INT, 0, &s, MPIMessage::BARRIER_JOIN);
+            recv(r, 0, nullptr, MPI_INT, 0, &s, MpiMessageType::BARRIER_JOIN);
             SPDLOG_TRACE("MPI - recv barrier join {}", s.MPI_SOURCE);
         }
     } else {
         // Tell the root that we're waiting
         SPDLOG_TRACE("MPI - barrier join {}", thisRank);
-        send(thisRank, 0, nullptr, MPI_INT, 0, MPIMessage::BARRIER_JOIN);
+        send(thisRank, 0, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_JOIN);
     }
 
     // Rank 0 broadcasts that the barrier is done (the others block here)
-    broadcast(0, thisRank, nullptr, MPI_INT, 0, MPIMessage::BARRIER_DONE);
+    broadcast(0, thisRank, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_DONE);
     SPDLOG_TRACE("MPI - barrier done {}", thisRank);
 }
 
@@ -1477,9 +1473,10 @@ void MpiWorld::initLocalQueues()
     }
 }
 
-std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
-                                                          int recvRank,
-                                                          int batchSize)
+// TODO(mpi-opt): double-check that the fast (no-async) path is fast
+MpiMessage MpiWorld::recvBatchReturnLast(int sendRank,
+                                         int recvRank,
+                                         int batchSize)
 {
     std::shared_ptr<MpiMessageBuffer> umb =
       getUnackedMessageBuffer(sendRank, recvRank);
@@ -1499,7 +1496,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
     // Recv message: first we receive all messages for which there is an id
     // in the unacknowledged buffer but no msg. Note that these messages
     // (batchSize - 1) were `irecv`-ed before ours.
-    std::shared_ptr<MPIMessage> ourMsg;
+    MpiMessage ourMsg;
     auto msgIt = umb->getFirstNullMsg();
     if (isLocal) {
         // First receive messages that happened before us
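
The send/recv changes above establish a single ownership rule for payloads: on the local path, send() mallocs a copy of the caller's data and enqueues the MpiMessage by value; on the remote path, parseMpiMsg mallocs the payload; in both cases doRecv copies into the application buffer and frees the pointer. Below is a minimal sketch of the receiving half of that contract (illustrative only; consumeLocalMessage is not part of the patch, and it assumes a queue reference obtained elsewhere, e.g. via getLocalQueue):

    #include <faabric/mpi/MpiMessage.h>
    #include <faabric/mpi/MpiWorld.h>
    #include <faabric/util/memory.h>

    #include <algorithm>
    #include <cstring>

    using namespace faabric::mpi;

    void consumeLocalMessage(InMemoryMpiQueue& queue,
                             uint8_t* userBuffer,
                             size_t userBufferSize)
    {
        // Messages are stored by value; the payload pointer was malloc-ed by
        // the sending rank (local path) or by parseMpiMsg (remote path)
        MpiMessage m = queue.dequeue();

        if (m.count > 0 && m.buffer != nullptr) {
            // Never overflow the receiver's buffer
            size_t bytesToCopy = std::min(payloadSize(m), userBufferSize);
            std::memcpy(userBuffer, m.buffer, bytesToCopy);

            // The receiving side owns, and must free, the payload
            faabric::util::free(m.buffer);
        }
    }
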
diff --git a/src/mpi/mpi.proto b/src/mpi/mpi.proto
deleted file mode 100644
index 80a690820..000000000
--- a/src/mpi/mpi.proto
+++ /dev/null
@@ -1,35 +0,0 @@
-syntax = "proto3";
-
-package faabric.mpi;
-
-message MPIMessage {
-    enum MPIMessageType {
-        NORMAL = 0;
-        BARRIER_JOIN = 1;
-        BARRIER_DONE = 2;
-        SCATTER = 3;
-        GATHER = 4;
-        ALLGATHER = 5;
-        REDUCE = 6;
-        SCAN = 7;
-        ALLREDUCE = 8;
-        ALLTOALL = 9;
-        SENDRECV = 10;
-        BROADCAST = 11;
-    };
-
-    MPIMessageType messageType = 1;
-
-    int32 id = 2;
-    int32 worldId = 3;
-    int32 sender = 4;
-    int32 destination = 5;
-    int32 type = 6;
-    int32 count = 7;
-
-    // For remote messaging
-    optional bytes buffer = 8;
-
-    // For local messaging
-    optional int64 bufferPtr = 9;
-}
diff --git a/tests/dist/mpi/mpi_native.cpp b/tests/dist/mpi/mpi_native.cpp
index d41235940..a499fb357 100644
--- a/tests/dist/mpi/mpi_native.cpp
+++ b/tests/dist/mpi/mpi_native.cpp
@@ -2,9 +2,9 @@
 
 #include <faabric/executor/ExecutorContext.h>
 #include <faabric/mpi/MpiContext.h>
+#include <faabric/mpi/MpiMessage.h>
 #include <faabric/mpi/MpiWorld.h>
 #include <faabric/mpi/mpi.h>
-#include <faabric/mpi/mpi.pb.h>
 #include <faabric/scheduler/FunctionCallClient.h>
 #include <faabric/scheduler/Scheduler.h>
 #include <faabric/snapshot/SnapshotClient.h>
@@ -126,7 +126,7 @@ int MPI_Send(const void* buf,
                              (uint8_t*)buf,
                              datatype,
                              count,
-                             MPIMessage::NORMAL);
+                             MpiMessageType::NORMAL);
 
     return MPI_SUCCESS;
 }
@@ -159,7 +159,7 @@ int MPI_Recv(void* buf,
                              datatype,
                              count,
                              status,
-                             MPIMessage::NORMAL);
+                             MpiMessageType::NORMAL);
 
     return MPI_SUCCESS;
 }
@@ -245,7 +245,7 @@ int MPI_Bcast(void* buffer,
 
     int rank = executingContext.getRank();
     world.broadcast(
-      root, rank, (uint8_t*)buffer, datatype, count, MPIMessage::BROADCAST);
+      root, rank, (uint8_t*)buffer, datatype, count, MpiMessageType::BROADCAST);
     return MPI_SUCCESS;
 }
 
diff --git a/tests/test/mpi/test_mpi_message.cpp b/tests/test/mpi/test_mpi_message.cpp
new file mode 100644
index 000000000..9c79f8d3b
--- /dev/null
+++ b/tests/test/mpi/test_mpi_message.cpp
@@ -0,0 +1,123 @@
+#include <catch2/catch.hpp>
+
+#include <faabric/mpi/MpiMessage.h>
+#include <faabric/util/memory.h>
+
+#include <cstring>
+
+using namespace faabric::mpi;
+
+namespace tests {
+
+bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
+{
+    auto sizeA = msgSize(msgA);
+    auto sizeB = msgSize(msgB);
+
+    if (sizeA != sizeB) {
+        return false;
+    }
+
+    // First, compare the message body (excluding the pointer, which we
+    // know is at the end)
+    if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) {
+        return false;
+    }
+
+    // Check that if one buffer is null, the other must be null too
+    if (msgA.buffer == nullptr || msgB.buffer == nullptr) {
+        return msgA.buffer == msgB.buffer;
+    }
+
+    // If neither is null, both must point to the same data
+    auto payloadSizeA = payloadSize(msgA);
+    auto payloadSizeB = payloadSize(msgB);
+    // Assert, as this should pass given the previous comparisons
+    assert(payloadSizeA == payloadSizeB);
+
+    return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
+}
+
+TEST_CASE("Test getting a message size", "[mpi]")
+{
+    MpiMessage msg = { .id = 1,
+                       .worldId = 3,
+                       .sendRank = 3,
+                       .recvRank = 7,
+                       .typeSize = 1,
+                       .count = 3,
+                       .messageType = MpiMessageType::NORMAL };
+
+    size_t expectedMsgSize = 0;
+    size_t expectedPayloadSize = 0;
+
+    SECTION("Empty message")
+    {
+        msg.buffer = nullptr;
+        msg.count = 0;
+        expectedMsgSize = sizeof(MpiMessage);
+        expectedPayloadSize = 0;
+    }
+
+    SECTION("Non-empty message")
+    {
+        std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
+        msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
+        std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
+
+        expectedPayloadSize = sizeof(int) * nums.size();
+        expectedMsgSize = sizeof(MpiMessage) + expectedPayloadSize;
+    }
+
+    REQUIRE(expectedMsgSize == msgSize(msg));
+    REQUIRE(expectedPayloadSize == payloadSize(msg));
+
+    if (msg.buffer != nullptr) {
+        faabric::util::free(msg.buffer);
+    }
+}
+
+TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
+{
+    MpiMessage msg = { .id = 1,
+                       .worldId = 3,
+                       .sendRank = 3,
+                       .recvRank = 7,
+                       .typeSize = 1,
+                       .count = 3,
+                       .messageType = MpiMessageType::NORMAL };
+
+    SECTION("Empty message")
+    {
+        msg.count = 0;
+        msg.buffer = nullptr;
+    }
+
+    SECTION("Non-empty message")
+    {
+        std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
+        msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
+        std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
+    }
+
+    // Serialise and de-serialise
+    std::vector<uint8_t> buffer(msgSize(msg));
+    serializeMpiMsg(buffer, msg);
+
+    MpiMessage parsedMsg;
+    parseMpiMsg(buffer, &parsedMsg);
+
+    REQUIRE(areMpiMsgEqual(msg, parsedMsg));
+
+    if (msg.buffer != nullptr) {
+        faabric::util::free(msg.buffer);
+    }
+    if (parsedMsg.buffer != nullptr) {
+        faabric::util::free(parsedMsg.buffer);
+    }
+}
+}
diff --git a/tests/test/mpi/test_mpi_message_buffer.cpp b/tests/test/mpi/test_mpi_message_buffer.cpp
index 1674172fd..710a3c259 100644
--- a/tests/test/mpi/test_mpi_message_buffer.cpp
+++ b/tests/test/mpi/test_mpi_message_buffer.cpp
@@ -21,7 +21,7 @@ MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
     pendingMsg.requestId = requestId;
 
     if (!nullMsg) {
-        pendingMsg.msg = std::make_shared<MPIMessage>();
+        pendingMsg.msg = std::make_shared<MpiMessage>();
     }
 
     return pendingMsg;
diff --git a/tests/test/mpi/test_mpi_world.cpp b/tests/test/mpi/test_mpi_world.cpp
index 2c0030b5f..8094cae7e 100644
--- a/tests/test/mpi/test_mpi_world.cpp
+++ b/tests/test/mpi/test_mpi_world.cpp
@@ -212,23 +212,22 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test local barrier", "[mpi]")
     world.destroy();
 }
 
-void checkMessage(MPIMessage& actualMessage,
+void checkMessage(MpiMessage& actualMessage,
                   int worldId,
                   int senderRank,
                   int destRank,
                   const std::vector<int>& data)
 {
     // Check the message contents
-    REQUIRE(actualMessage.worldid() == worldId);
-    REQUIRE(actualMessage.count() == data.size());
-    REQUIRE(actualMessage.destination() == destRank);
-    REQUIRE(actualMessage.sender() == senderRank);
-    REQUIRE(actualMessage.type() == FAABRIC_INT);
+    REQUIRE(actualMessage.worldId == worldId);
+    REQUIRE(actualMessage.count == data.size());
+    REQUIRE(actualMessage.recvRank == destRank);
+    REQUIRE(actualMessage.sendRank == senderRank);
+    REQUIRE(actualMessage.typeSize == FAABRIC_INT);
 
     // Check data
-    const auto* rawInts =
-      reinterpret_cast<const int*>(actualMessage.buffer().c_str());
-    size_t nInts = actualMessage.buffer().size() / sizeof(int);
+    const auto* rawInts = reinterpret_cast<const int*>(actualMessage.buffer);
+    size_t nInts = payloadSize(actualMessage) / sizeof(int);
     std::vector<int> actualData(rawInts, rawInts + nInts);
     REQUIRE(actualData == data);
 }
@@ -396,10 +395,10 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send/recv message with no data", "[mpi]")
     SECTION("Check on queue")
     {
         // Check message content
-        MPIMessage actualMessage =
-          *(world.getLocalQueue(rankA1, rankA2)->dequeue());
-        REQUIRE(actualMessage.count() == 0);
-        REQUIRE(actualMessage.type() == FAABRIC_INT);
+        MpiMessage actualMessage =
+          world.getLocalQueue(rankA1, rankA2)->dequeue();
+        REQUIRE(actualMessage.count == 0);
+        REQUIRE(actualMessage.typeSize == FAABRIC_INT);
     }
 
     SECTION("Check receiving with null ptr")
@@ -502,7 +501,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]")
                         BYTES(messageData.data()),
                         MPI_INT,
                         messageData.size(),
-                        MPIMessage::BROADCAST);
+                        MpiMessageType::BROADCAST);
 
         // Recv on all non-root ranks
         for (int rank = 0; rank < worldSize; rank++) {
@@ -515,7 +514,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]")
                             BYTES(actual.data()),
                             MPI_INT,
                             3,
-                            MPIMessage::BROADCAST);
+                            MpiMessageType::BROADCAST);
             REQUIRE(actual == messageData);
         }
     }
diff --git a/tests/test/mpi/test_multiple_mpi_worlds.cpp b/tests/test/mpi/test_multiple_mpi_worlds.cpp
index 735556f6e..a6e74f0b5 100644
--- a/tests/test/mpi/test_multiple_mpi_worlds.cpp
+++ b/tests/test/mpi/test_multiple_mpi_worlds.cpp
@@ -164,7 +164,7 @@ TEST_CASE_METHOD(MultiWorldMpiTestFixture,
         REQUIRE(worldA.getLocalQueueSize(rankA2, 0) == 0);
         const std::shared_ptr<InMemoryMpiQueue>& queueA2 =
           worldA.getLocalQueue(rankA1, rankA2);
-        MPIMessage actualMessage = *(queueA2->dequeue());
+        MpiMessage actualMessage = queueA2->dequeue();
         // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData);
 
         // Check for world B
@@ -174,7 +174,7 @@ TEST_CASE_METHOD(MultiWorldMpiTestFixture,
         REQUIRE(worldB.getLocalQueueSize(rankA2, 0) == 0);
         const std::shared_ptr<InMemoryMpiQueue>& queueA2B =
           worldB.getLocalQueue(rankA1, rankA2);
-        actualMessage = *(queueA2B->dequeue());
+        actualMessage = queueA2B->dequeue();
         // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData);
     }
 
diff --git a/tests/test/mpi/test_remote_mpi_worlds.cpp b/tests/test/mpi/test_remote_mpi_worlds.cpp
index 1e56b48b1..54662929f 100644
--- a/tests/test/mpi/test_remote_mpi_worlds.cpp
+++ b/tests/test/mpi/test_remote_mpi_worlds.cpp
@@ -21,12 +21,11 @@ using namespace faabric::mpi;
 using namespace faabric::scheduler;
 
 namespace tests {
-std::set<int> getReceiversFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getReceiversFromMessages(std::vector<MpiMessage> msgs)
 {
     std::set<int> receivers;
     for (const auto& msg : msgs) {
-        receivers.insert(msg->destination());
+        receivers.insert(msg.recvRank);
     }
 
     return receivers;
@@ -108,14 +107,14 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
                             BYTES(messageData.data()),
                             MPI_INT,
                             messageData.size(),
-                            MPIMessage::BROADCAST);
+                            MpiMessageType::BROADCAST);
     } else {
         otherWorld.broadcast(sendRank,
                              recvRank,
                              BYTES(messageData.data()),
                              MPI_INT,
                              messageData.size(),
-                             MPIMessage::BROADCAST);
+                             MpiMessageType::BROADCAST);
     }
     auto msgs = getMpiMockedMessages(recvRank);
     REQUIRE(msgs.size() == expectedNumMsg);
@@ -219,12 +218,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     thisWorld.destroy();
 }
 
-std::set<int> getMsgCountsFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getMsgCountsFromMessages(std::vector<MpiMessage> msgs)
 {
     std::set<int> counts;
     for (const auto& msg : msgs) {
-        counts.insert(msg->count());
+        counts.insert(msg.count);
     }
 
     return counts;

From 86a8e7647807c062bd7170440aad6746151f92a1 Mon Sep 17 00:00:00 2001
From: Carlos Segarra <carlos@carlossegarra.com>
Date: Wed, 20 Mar 2024 16:23:49 +0000
Subject: [PATCH 2/2] mpi: make struct 8-byte aligned

---
 include/faabric/mpi/MpiMessage.h | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h
index 7c85fde48..853f3dce9 100644
--- a/include/faabric/mpi/MpiMessage.h
+++ b/include/faabric/mpi/MpiMessage.h
@@ -21,6 +21,18 @@ enum MpiMessageType : int32_t
     BROADCAST = 11,
 };
 
+/* Simple fixed-size C-struct to capture the state of an MPI message moving
+ * through Faabric.
+ *
+ * We require a fixed size, and no unique pointers, to be able to use
+ * high-throughput in-memory ring-buffers to send the messages around.
+ * This also means that we manually malloc/free the data pointer. The message
+ * size is:
+ * 7 * int32_t = 7 * 4 bytes = 28 bytes
+ * 1 * int32_t (padding) = 4 bytes
+ * 1 * void* = 1 * 8 bytes = 8 bytes
+ * total = 40 bytes = 5 * 8 so the struct is 8 byte-aligned
+ */
 struct MpiMessage
 {
     int32_t id;
@@ -30,8 +42,10 @@ struct MpiMessage
     int32_t typeSize;
     int32_t count;
     MpiMessageType messageType;
+    int32_t __make_8_byte_aligned;
     void* buffer;
 };
+static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!");
 
 inline size_t payloadSize(const MpiMessage& msg)
 {
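
For illustration, the layout described in the comment above can be checked at compile time; this sketch is not part of the patch and assumes a typical 64-bit platform with 8-byte pointers:

    #include <faabric/mpi/MpiMessage.h>

    #include <cstddef>

    using namespace faabric::mpi;

    // 7 int32_t fields at offsets 0, 4, ..., 24, the explicit 4-byte padding
    // field at offset 28, and the 8-byte pointer at offset 32: 40 bytes total
    static_assert(offsetof(MpiMessage, messageType) == 24);
    static_assert(offsetof(MpiMessage, buffer) == 32);
    static_assert(sizeof(MpiMessage) == 40);
    static_assert(alignof(MpiMessage) == alignof(void*));
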