util(queue): explore different in memory queue implementations for mpi #380

Closed · wants to merge 2 commits
21 changes: 19 additions & 2 deletions include/faabric/mpi/MpiWorld.h
@@ -21,15 +21,24 @@

namespace faabric::mpi {

struct MpiMessage {
int32_t id;
int32_t worldId;
int32_t sendRank;
int32_t recvRank;
int32_t type;
int32_t count;
void* buffer;
};

// -----------------------------------
// Mocking
// -----------------------------------
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
// as the broker already has mocking capabilities
std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);

typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
InMemoryMpiQueue;
typedef faabric::util::Queue<std::unique_ptr<MpiMessage>> InMemoryMpiQueue;

class MpiWorld
{
@@ -264,5 +273,13 @@ class MpiWorld
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);

// Abstraction of the bulk of the recv work, shared among various functions
void doRecv(std::unique_ptr<MpiMessage> m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
};
}
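For orientation, here is a minimal standalone sketch of how the new plain MpiMessage struct and the unique_ptr-based InMemoryMpiQueue typedef are meant to be used on the local path. It is illustrative only: the helper names sendIntLocally and recvIntLocally are made up, and it assumes the faabric::util::Queue interface shown in include/faabric/util/queue.h below supports the move-only element type that the new typedef requires.

// Illustrative only: local-rank send/receive through the new typedef
#include <faabric/mpi/MpiWorld.h>

#include <cstdlib>
#include <cstring>
#include <memory>

using faabric::mpi::InMemoryMpiQueue;
using faabric::mpi::MpiMessage;

// Hypothetical helper: copy one int into a message that the queue owns
void sendIntLocally(InMemoryMpiQueue& queue, int sendRank, int recvRank, int value)
{
    void* payload = malloc(sizeof(int));
    std::memcpy(payload, &value, sizeof(int));

    queue.enqueue(std::make_unique<MpiMessage>(MpiMessage{ .id = 0,
                                                           .worldId = 0,
                                                           .sendRank = sendRank,
                                                           .recvRank = recvRank,
                                                           .type = 0,
                                                           .count = 1,
                                                           .buffer = payload }));
}

// Hypothetical helper: take ownership of the message, copy out, free the payload
int recvIntLocally(InMemoryMpiQueue& queue)
{
    std::unique_ptr<MpiMessage> m = queue.dequeue();

    int value = 0;
    std::memcpy(&value, m->buffer, sizeof(int));
    free(m->buffer);

    return value;
}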
32 changes: 22 additions & 10 deletions include/faabric/util/queue.h
@@ -21,10 +21,22 @@ class QueueTimeoutException : public faabric::util::FaabricException
};

template<typename T>
class Queue
class BaseQueue
{
public:
virtual ~BaseQueue() = default;

virtual void enqueue(T value) = 0;

virtual T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) = 0;

virtual void drain() = 0;

virtual void reset() = 0;
};

template<typename T>
class Queue : public BaseQueue<T>
{
public:
void enqueue(T value)
void enqueue(T value) override
{
UniqueLock lock(mx);

@@ -46,7 +58,7 @@
}
}

T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) override
{
UniqueLock lock(mx);

@@ -110,7 +122,7 @@
}
}

void drain()
void drain() override
{
UniqueLock lock(mx);

@@ -125,7 +137,7 @@
return mq.size();
}

void reset()
void reset() override
{
UniqueLock lock(mx);

@@ -144,7 +156,7 @@
// consumer queue
// https://github.com/cameron314/readerwriterqueue
template<typename T>
class FixedCapacityQueue
class FixedCapacityQueue : public BaseQueue<T>
{
public:
FixedCapacityQueue(int capacity)
@@ -153,7 +165,7 @@
FixedCapacityQueue()
: mq(DEFAULT_QUEUE_SIZE){};

void enqueue(T value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
void enqueue(T value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) override
{
if (timeoutMs <= 0) {
SPDLOG_ERROR("Invalid queue timeout: {} <= 0", timeoutMs);
@@ -169,7 +181,7 @@

void dequeueIfPresent(T* res) { mq.try_dequeue(*res); }

T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) override
{
if (timeoutMs <= 0) {
SPDLOG_ERROR("Invalid queue timeout: {} <= 0", timeoutMs);
@@ -190,7 +202,7 @@
throw std::runtime_error("Peek not implemented");
}

void drain(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
void drain(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) override
{
T value;
bool success;
Expand All @@ -204,7 +216,7 @@ class FixedCapacityQueue

long size() { return mq.size_approx(); }

void reset()
void reset() override
{
moodycamel::BlockingReaderWriterCircularBuffer<T> empty(
mq.max_capacity());
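A rough sketch of what the shared BaseQueue interface enables: call sites can be written against the abstract type rather than a concrete queue. This assumes the public interface and virtual destructor added to BaseQueue above; it uses Queue directly, since FixedCapacityQueue's timeout-taking enqueue and drain overloads differ from the base signatures, and processAll is a made-up helper name.

// Illustrative only: consuming a queue through the BaseQueue interface,
// independent of the concrete implementation behind it
#include <faabric/util/queue.h>

#include <string>

using faabric::util::BaseQueue;
using faabric::util::Queue;

// Hypothetical helper: drain n messages through the abstract interface
void processAll(BaseQueue<std::string>& queue, int nMessages)
{
    for (int i = 0; i < nMessages; i++) {
        std::string msg = queue.dequeue();
        // ... handle msg here ...
    }
}

int main()
{
    Queue<std::string> q;
    q.enqueue("hello");
    q.enqueue("world");

    processAll(q, 2);

    return 0;
}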
104 changes: 82 additions & 22 deletions src/mpi/MpiWorld.cpp
@@ -505,33 +505,44 @@ void MpiWorld::send(int sendRank,
// Generate a message ID
int msgId = (localMsgCount + 1) % INT32_MAX;

// Create the message
auto m = std::make_shared<MPIMessage>();
m->set_id(msgId);
m->set_worldid(id);
m->set_sender(sendRank);
m->set_destination(recvRank);
m->set_type(dataType->id);
m->set_count(count);
m->set_messagetype(messageType);

// Set up message data
if (count > 0 && buffer != nullptr) {
m->set_buffer(buffer, dataType->size * count);
}

// Mock the message sending in tests
/*
if (faabric::util::isMockMode()) {
mpiMockedMessages[sendRank].push_back(m);
return;
}
*/

// Dispatch the message locally or globally
if (isLocal) {
void* bufferPtr = malloc(count * dataType->size);
std::memcpy(bufferPtr, buffer, count * dataType->size);

auto msg = std::make_unique<MpiMessage>(MpiMessage{
.id = msgId, .worldId = id, .sendRank = sendRank,
.recvRank = recvRank, .type = dataType->id, .count = count,
.buffer = bufferPtr
});

SPDLOG_TRACE(
"MPI - send {} -> {} ({})", sendRank, recvRank, messageType);
getLocalQueue(sendRank, recvRank)->enqueue(std::move(m));
getLocalQueue(sendRank, recvRank)->enqueue(std::move(msg));
} else {
// Create the message
auto m = std::make_shared<MPIMessage>();
m->set_id(msgId);
m->set_worldid(id);
m->set_sender(sendRank);
m->set_destination(recvRank);
m->set_type(dataType->id);
m->set_count(count);
m->set_messagetype(messageType);

// Set up message data
if (count > 0 && buffer != nullptr) {
m->set_buffer(buffer, dataType->size * count);
}

SPDLOG_TRACE(
"MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType);
sendRemoteMpiMessage(otherHost, sendRank, recvRank, m);
@@ -572,11 +583,54 @@ void MpiWorld::recv(int sendRank,
return;
}

// Recv message from underlying transport
std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank);
bool isLocal = getHostForRank(sendRank) == getHostForRank(recvRank);

if (isLocal) {
std::unique_ptr<MpiMessage> m = getLocalQueue(sendRank, recvRank)->dequeue();

// Do the processing
doRecv(std::move(m), buffer, dataType, count, status, messageType);
} else {
// Recv message from underlying transport
std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank);

// Do the processing
doRecv(m, buffer, dataType, count, status, messageType);
// Do the processing
doRecv(m, buffer, dataType, count, status, messageType);
}
}

void MpiWorld::doRecv(std::unique_ptr<MpiMessage> m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
MPI_Status* status,
MPIMessage::MPIMessageType messageType)
{
// Assert message integrity
// Note - these checks won't happen in Release builds
if (m->type != messageType) {
SPDLOG_ERROR("Different message types (got: {}, expected: {})",
m->type,
messageType);
}
assert(m->type == messageType);
assert(m->count <= count);

// TODO - avoid copy here
// Copy message data
if (m->count > 0) {
std::memcpy(buffer, m->buffer, m->count * dataType->size);
free(m->buffer);
}

// Set status values if required
if (status != nullptr) {
status->MPI_SOURCE = m->sendRank;
status->MPI_ERROR = MPI_SUCCESS;

// Take the message size here as the receive count may be larger
status->bytesSize = m->count * dataType->size;

// TODO - thread through tag
status->MPI_TAG = -1;
}
}

void MpiWorld::doRecv(std::shared_ptr<MPIMessage>& m,
@@ -1395,15 +1449,18 @@ void MpiWorld::allToAll(int rank,
// queues.
void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)
{
throw std::runtime_error("Probe not implemented!");
/*
const std::shared_ptr<InMemoryMpiQueue>& queue =
getLocalQueue(sendRank, recvRank);
// 30/12/21 - Peek will throw a runtime error
std::shared_ptr<MPIMessage> m = *(queue->peek());

faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type());
// faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type());
status->bytesSize = m->count() * datatype->size;
status->MPI_ERROR = 0;
status->MPI_SOURCE = m->sender();
*/
}

void MpiWorld::barrier(int thisRank)
@@ -1456,6 +1513,7 @@ void MpiWorld::initLocalQueues()
}
}

// TODO: double-check that the fast (no-async) path is fast
std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
int recvRank,
int batchSize)
@@ -1482,14 +1540,15 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
auto msgIt = umb->getFirstNullMsg();
if (isLocal) {
// First receive messages that happened before us
/*
for (int i = 0; i < batchSize - 1; i++) {
try {
SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank);
auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue();

// Put the unacked message in the UMB
assert(!msgIt->isAcknowledged());
msgIt->acknowledge(pendingMsg);
// msgIt->acknowledge(pendingMsg);
msgIt++;
} catch (faabric::util::QueueTimeoutException& e) {
SPDLOG_ERROR(
@@ -1516,6 +1575,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
recvRank);
throw e;
}
*/
} else {
// First receive messages that happened before us
for (int i = 0; i < batchSize - 1; i++) {
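Finally, a condensed standalone sketch of the local message flow this diff moves towards: the sender copies the caller's buffer into a malloc'd payload owned by the message, and the receiver copies it out once and frees it. Ranks, tags, and the queue itself are simplified here (a plain std::queue, no locking or timeouts), so this illustrates the ownership pattern only, not the faabric implementation.

// Standalone sketch of the local send/recv ownership pattern in this diff
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <queue>

struct MpiMessage
{
    int32_t id;
    int32_t worldId;
    int32_t sendRank;
    int32_t recvRank;
    int32_t type;
    int32_t count;
    void* buffer;
};

using LocalQueue = std::queue<std::unique_ptr<MpiMessage>>;

void localSend(LocalQueue& q, int sendRank, int recvRank, const int* data, int count)
{
    // Copy the caller's buffer so the message owns its payload
    void* payload = malloc(count * sizeof(int));
    std::memcpy(payload, data, count * sizeof(int));

    q.push(std::make_unique<MpiMessage>(MpiMessage{ .id = 0,
                                                    .worldId = 0,
                                                    .sendRank = sendRank,
                                                    .recvRank = recvRank,
                                                    .type = 0,
                                                    .count = count,
                                                    .buffer = payload }));
}

void localRecv(LocalQueue& q, int* outBuffer, int count)
{
    std::unique_ptr<MpiMessage> m = std::move(q.front());
    q.pop();

    assert(m->count <= count);

    // Single copy into the caller's buffer, then release the payload
    std::memcpy(outBuffer, m->buffer, m->count * sizeof(int));
    free(m->buffer);
}

int main()
{
    LocalQueue q;
    int payload[3] = { 1, 2, 3 };
    int received[3] = { 0, 0, 0 };

    localSend(q, 0, 1, payload, 3);
    localRecv(q, received, 3);

    return 0;
}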